03使用tensorflow實現(xiàn)邏輯回歸(mnist數(shù)據(jù)集)

利用基本的mlp實現(xiàn)邏輯回歸(mnist數(shù)據(jù)集)

import tensorflow as tf
import numpy
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
rng = numpy.random

# 設(shè)置訓(xùn)練參數(shù)
learning_rate = 0.01
training_epochs = 25
batch_size = 100
display_step = 1

# 加載訓(xùn)練數(shù)據(jù),第一次加載時需要下載數(shù)據(jù)集到下面自行創(chuàng)建的目錄中
# mnist數(shù)據(jù)集中為28x28x1,單通道數(shù)據(jù)集
mnist = input_data.read_data_sets("MNIST_DATA/", one_hot=True)

# 占位節(jié)點,None 表示將每一個批次輸入的圖片個數(shù)
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

# 單層網(wǎng)絡(luò),只有10個節(jié)點
W = tf.Variable(tf.zeros([784, 10]), name="weight")
b = tf.Variable(tf.zeros([10]), name="bias")

# 輸出
pred = tf.nn.softmax(tf.matmul(X, W) + b)

# 損失函數(shù)
cost = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(pred), reduction_indices=1))
# 優(yōu)化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# 初始化參數(shù)
init = tf.global_variables_initializer()

# 進入圖
with tf.Session() as sess:
    # 圖初始化
    sess.run(init)

    # 訓(xùn)練圖
    for epoch in range(training_epochs):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / batch_size)

        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)

            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs, Y: batch_ys})

            avg_cost += c / total_batch

        if (epoch + 1) % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

    print("Optimization Finished!")

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    # Calculate accuracy for 3000 examples
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))
# Epoch: 0025 cost= 0.333732169
# Optimization Finished!
# Accuracy: 0.8883333
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

友情鏈接更多精彩內(nèi)容