文章作者:Tyan
博客:noahsnail.com | CSDN | 簡書
本文主要是使用 TensorFlow 和 MNIST 數據集來訓練神經網絡。
#!/usr/bin/env python
# _*_ coding: utf-8 _*_
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset through the tutorial helper. The data files are kept
# under the 'MNIST_data' directory and labels are requested one-hot encoded.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Evaluation part of the model.

# Cache of evaluation ops, keyed by the identity of the (W, b) pair they were
# built for. Without it every call would add a new set of placeholders and ops
# to the default graph, which then grows for as long as training runs.
_ACCURACY_OPS = {}


def compute_accuracy(W, b):
    """Return the accuracy of the softmax model on the MNIST test set.

    The evaluation ops are built on the first call for a given (W, b) pair
    and reused afterwards; only the session run happens on every call.

    Relies on the module-level ``sess`` (session used to run the ops) and
    ``mnist`` (dataset providing ``test.images`` and ``test.labels``).

    Args:
        W: weights, multiplied with a [None, 784] float32 input.
        b: bias added to the result of the multiplication.

    Returns:
        The value of ``sess.run`` on the accuracy op: the mean of the
        per-sample "prediction matches label" indicator over the test set.
    """
    key = (id(W), id(b))
    if key not in _ACCURACY_OPS:
        # Placeholder for the test images.
        x = tf.placeholder(tf.float32, [None, 784])
        # Placeholder for the one-hot ground-truth labels of the test images.
        y_ = tf.placeholder(tf.float32, [None, 10])
        # Predicted per-class probabilities (10 classes).
        y = tf.nn.softmax(tf.matmul(x, W) + b)
        # A sample is correct when the index of the largest predicted
        # probability equals the index of the largest entry of its label.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        # Accuracy is the mean of the correct/incorrect flags cast to float.
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        # W and b are stored alongside the ops so they stay referenced and
        # their ids cannot be recycled for different objects.
        _ACCURACY_OPS[key] = (x, y_, accuracy, W, b)

    x, y_, accuracy = _ACCURACY_OPS[key][:3]
    # Feed the whole test set and return the computed accuracy.
    return sess.run(
        accuracy,
        feed_dict={x: mnist.test.images, y_: mnist.test.labels},
    )
# Training part of the model.
# The network defined below has a single layer: softmax(x * W + b).

# Input placeholder: the number of samples is left open (None), while each
# sample has a fixed size of 784 values.
x = tf.placeholder(tf.float32, [None, 784])
# Weights of the layer, initialised to zeros.
W = tf.Variable(tf.zeros([784, 10]))
# Biases of the layer, initialised to zeros.
b = tf.Variable(tf.zeros([10]))
# Raw (pre-softmax) class scores; kept separate so the loss can be computed
# directly from them.
logits = tf.matmul(x, W) + b
# Output of the layer: softmax activation over the class scores.
y = tf.nn.softmax(logits)
# Placeholder for the one-hot ground-truth labels of the training samples.
y_ = tf.placeholder(tf.float32, [None, 10])
# Loss: cross-entropy between labels and predictions, averaged over the batch.
# It is computed from the logits with the fused TensorFlow op rather than as
# -sum(y_ * log(softmax(...))): the hand-written form is numerically unstable
# because log() of a probability that underflows to 0 produces NaN.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
# One training step: gradient descent with a learning rate of 0.5.
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# Op that initialises all variables.
init = tf.global_variables_initializer()
# Session used for both training and evaluation (compute_accuracy reads it).
sess = tf.Session()
try:
    # Run the variable initialisation.
    sess.run(init)
    # Training loop.
    for i in range(1000):
        # Take the next 100 samples from the MNIST training set.
        batch_xs, batch_ys = mnist.train.next_batch(100)
        # Run one training step, feeding the batch into x and y_.
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Report test accuracy every 100 iterations. The call form of print
        # gives the same output on Python 2 and also runs on Python 3.
        if i % 100 == 0:
            print(compute_accuracy(W, b))
finally:
    # Release the resources held by the session even if training fails.
    sess.close()
執行結果如下:
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.4075
0.8948
0.9031
0.9074
0.9037
0.9125
0.9158
0.912
0.9181
0.9169