一:數據集
采用MNIST數據集:--》官網
數據集被分成兩部分:60000行的訓練數據集和10000行的測試數據集。
其中每一張圖片包含28*28個像素,我們把這個數組展開成一個向量,長度為28*28=784.在MNIST訓練數據集中mnist.train.images是一個形狀為[60000,784]的張量,第一個維度數字用來索引圖片,第二個維度數字用來索引每張圖片中的像素點。圖片里的某個像素的強度值介於0-1之間。

MNIST數據集的標簽是介於0-9的數字,我們把便簽轉化為‘one-hot vectors’.一個one-hot向量除了某一位數字1以外,其余維度數字都是0.比如標簽0將表示為([1,0,0,0,0,0,0,0,0,0,0]),標簽3表示為([0,0,0,1,0,0,0,0,0,0]).所以標簽相當於[60000,10]的數字矩陣。
我們的結果是0-9,我們的模型可能推測出一張圖片是數字9的概率為80%,是數字8的概率為10%,然后其他數字的概率更小,總體概率加起來等於1.這相當於一個使用softmax回歸模型的案例。

下面使用softmax模型來預測:
# MNIST數據集 手寫數字 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 載入數據集,如果沒有下載,程序會自動下載 mnist=input_data.read_data_sets('MNIST_data',one_hot=True) # 每個批次的大小 batch_size=100 # 計算一共有多少個批次 n_batch=mnist.train.num_examples//batch_size # 定義兩個placeholder x=tf.placeholder(tf.float32,[None,784]) y=tf.placeholder(tf.float32,[None,10]) # 創建一個簡單的神經網絡 W=tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) prediction=tf.nn.softmax(tf.matmul(x,W)+b) # 二次代價函數 loss=tf.reduce_mean(tf.square(y-prediction)) # 使用梯度下降法 train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化變量 init=tf.global_variables_initializer() # 求最大值在哪個位置,結果存放在一個布爾值列表中 correct_prediction=tf.equal(tf.argmax(y,1),tf.arg_max(prediction,1))# argmax返回一維張量中最大值所在的位置 # 求准確率 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) # cast作用是將布爾值轉換為浮點型。 with tf.Session() as sess: sess.run(init) for epoch in range(21): # 訓練20次 for batch in range(n_batch): # 每次喂入一定的數據 batch_xs,batch_ys=mnist.train.next_batch(batch_size) sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) #求准確率 acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print('Iter:'+str(epoch)+',Testing Accuracy:'+str(acc))
# 結果 # 可以看出每次訓練准確率都在提高 Iter:0,Testing Accuracy:0.8301 Iter:1,Testing Accuracy:0.8706 Iter:2,Testing Accuracy:0.8811 Iter:3,Testing Accuracy:0.8883 Iter:4,Testing Accuracy:0.8943 Iter:5,Testing Accuracy:0.8966 Iter:6,Testing Accuracy:0.9002 Iter:7,Testing Accuracy:0.9017 Iter:8,Testing Accuracy:0.9043 Iter:9,Testing Accuracy:0.9052 Iter:10,Testing Accuracy:0.9061 Iter:11,Testing Accuracy:0.9071 Iter:12,Testing Accuracy:0.908 Iter:13,Testing Accuracy:0.9096 Iter:14,Testing Accuracy:0.9094 Iter:15,Testing Accuracy:0.9102 Iter:16,Testing Accuracy:0.9116 Iter:17,Testing Accuracy:0.9119 Iter:18,Testing Accuracy:0.9126 Iter:19,Testing Accuracy:0.9134 Iter:20,Testing Accuracy:0.9136
