tensorflow的基本用法(七)——使用MNIST訓練神經(jīng)網(wǎng)絡

文章作者:Tyan
博客:noahsnail.com ?|? CSDN ?|? 簡書

本文主要是使用tensorflow和mnist數(shù)據(jù)集來訓練神經(jīng)網(wǎng)絡。

#!/usr/bin/env python
# _*_ coding: utf-8 _*_

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 下載mnist數(shù)據(jù)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


# 定義神經(jīng)網(wǎng)絡模型的評估部分
def compute_accuracy(W, b):
    # 定義測試數(shù)據(jù)的placeholder
    x = tf.placeholder(tf.float32, [None, 784])
    # 定義測試數(shù)據(jù)的真實標簽的placeholder
    y_ = tf.placeholder(tf.float32, [None, 10])
    # 定義預測值
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    # 判斷預測值y和真實值y_中最大數(shù)的索引是否一致,y的值為1-10概率
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    # 計算準確率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # 輸入測試數(shù)據(jù),執(zhí)行準確率的計算并返回
    return sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

# 定義神經(jīng)網(wǎng)絡模型的訓練部分
# 下面定義的神經(jīng)網(wǎng)絡只有一層W*x+b
# 定義輸入數(shù)據(jù)placeholder,不定義輸入樣本的數(shù)目——None,但定義每個樣本的大小為784
x = tf.placeholder(tf.float32, [None, 784])
# 定義神經(jīng)網(wǎng)絡層的權重參數(shù)
W = tf.Variable(tf.zeros([784, 10]))
# 定義神經(jīng)網(wǎng)絡層的偏置參數(shù)
b = tf.Variable(tf.zeros([10]))
# 定義一層神經(jīng)網(wǎng)絡運算,激活函數(shù)為softmax
y = tf.nn.softmax(tf.matmul(x, W) + b)
# 定義訓練數(shù)據(jù)真實標簽的placeholder
y_ = tf.placeholder(tf.float32, [None, 10])
# 定義損失函數(shù),損失函數(shù)為交叉熵,reduction_indices表示沿著tensor的哪個緯度來求和
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
# 定義神經(jīng)網(wǎng)絡的訓練步驟,使用的是梯度下降法,學習率為0.5
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 初始化所有變量
init = tf.global_variables_initializer()
# 定義Session
sess = tf.Session()
# 執(zhí)行變量的初始化
sess.run(init)
# 迭代進行訓練
for i in range(1000):
    # 取出mnist數(shù)據(jù)集中的100個數(shù)據(jù)
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # 執(zhí)行訓練過程并傳入真實數(shù)據(jù)x, y_
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if i % 100 == 0:
        print compute_accuracy(W, b) 

執(zhí)行結果如下:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0.4075
0.8948
0.9031
0.9074
0.9037
0.9125
0.9158
0.912
0.9181
0.9169
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
【社區(qū)內容提示】社區(qū)部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容