背景:MNIST 數據集來自美國國家標准與技術研究所,National Institute of Standardsand Technology (NIST).
數據集由來自250個不同人手寫的數字構成,其中50%是高中學生,50%來自人口普查局(the Census Bureau)的工作人員
其中,訓練集55000驗證集5000 測試集10000。MNIST 數據集可在 http://yann.lecun.com/exdb/mnist/ 獲取,也可以不下載因為TensorFlow提供了數據集讀取方法。
一、數據下載和讀取
1.1數據下載
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
MNIST數據集文件在讀取時如果指定目錄下不存在,則會自動去下載,需等待一定時間如果已經存在了,則直接讀取。
1.2數據讀取
import tensorflow.examples.tutorials.mnist.input_data as input_data mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) print('訓練集train 數量:',mnist.train.num_examples,',驗證集 validation 數量:',mnist.validation.num_examples,',測試集test 數量:',mnist.test.num_examples)
print('train images shape:',mnist.train.images.shape,'labels shaple:',mnist.train.labels.shape)
#(55000,784)意思是5w5千條,每一條784位(圖片大小28*28);(55000,10):10分類 One Hot編碼
1.2.1數據的批量讀取
MNIST數據包里提供了 .next_batch方法,可實現批量讀取
1.3可視化image
def plot_image(image): plt.imshow(image.reshape(28,28),cmap='binary') plt.show() plot_image(mnist.train.images[1])
二、標簽數據與獨熱編碼
2.1獨熱編碼簡介
一種稀疏向量,其中:一個元素設為1,所有其他元素均設為0
獨熱編碼常用於表示擁有有限個可能值的字符串或標識符
例如:假設某個植物學數據集記錄了15000個不同的物種,其中每個物種都用獨一無二的字符串標識符來表示。在特征工程過程中,可能需要將這些字符串標識符編碼為獨熱向量,向量的大小為15000
為什么要采用one hot編碼?
- 將離散特征的取值擴展到了歐式空間,離散特征的某個取值就對應歐式空間的某個點
- 機器學習算法中,特征之間距離的計算或相似度的常用計算方法都是基於歐式空間的
- 將離散型特征使用one-hot編碼,會讓特征之間的距離計算更加合理
獨熱編碼如何取值?
非獨熱編碼的標簽值:
三、模型構建
3.1定義待輸入數據的占位符
#mnist 中每張圖片共有28*28=784個像素點
x=tf.placeholder(tf.float32,[None,784],name="X")
#0-9一共10個數字=>10個類別
y=tf.placeholder(tf.float32,[None,10],name="Y")
3.2定義模型變量
在本案例中,以正態分布的隨機數初始化權重W,以常數0初始化偏置b
W= tf.Variable(tf.random_normal([784,10]),name="W")
b= tf.Variable(tf.zeros([10]),name="b")
3.3定義前向計算和結果分類
forward = tf.matmul(x, W) + b #前向計算
當我們處理多分類任務時,通常需要使用Softmax Regression模型。
Softmax Regression會對每一類別估算出一個概率。
工作原理:將判定為某一類的特征相加,然后將這些特征轉化為判定是這一類的概率。
pred = tf.nn.softmax(forward) #Softmax分類
四、訓練模型
4.1設置訓練參數
train_epochs = 50 #訓練輪數 batch_size=100 #單次訓練樣本數(批次大小) total_batch= int(mnist.train.num_examples/batch_size)#一輪訓練有多少批次 display_step=1 #顯示粒度 learning_rate=0.01 #學習率
4.2定義損失函數,選擇優化器
loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) #定義交叉熵損失函數 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) #梯度下降優化器
4.3定義准確率
#檢查預測類別tf.argmax(pred,1)與實際類別tf.argmax(y.1)的匹配情況,argmax能把最大值的下標取出來 correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))#相等則返回True #准確率,將布爾值轉化為浮點數,並計算平均值 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) sess = tf.Session()#聲明會活 init = tf.global_variables_initializer()#變量初始化 sess.run(init)
拓展部分:argmax()用法
4.4訓練過程
#開始訓練 for epoch in range(train_epochs): for batch in range(total_batch): xs,ys = mnist.train.next_batch(batch_size) # 讀取批次數據 sess.run(optimizer,feed_dict={x:xs,y:ys}) # 執行批次訓練 #total_batch個批次訓練完成后,使用驗證數據計算課差與准確率;驗證集沒有分批 loss,acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels}) #打印訓練過程中的詳細信息 if (epoch+1) % display_step == 0: print("Train Epoch:",'%02d' %(epoch+1),"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc)) print("Train Finished!")
運行結果為:
從結果可以看出,損失值Loss是趨於更小的,准確率Accuracy越來越高,可以修改參數,使得准確率(驗證集)到90%以上~~
五、評估模型
完成訓練后,在測試集上評估模型的准確率
accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print("Test Accuracy:",accu_test)
Note:測試環節中沒有分批,所有測試數據1萬條直接執行
測試結果86.14%, 上面模型訓練時是由驗證集的數據做出來的,是86.46%。
在訓練集上評估模型的准確率
accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels}) print("Test Accuracy:",accu_train)
通過數據集的合理划分,訓練和驗證效果差不多,准確率比較滿意則可以進行模型預測。
六、模型應用和可視化
#模型預測 #由於pred預測結果是one-hot編碼格式,所以需要轉換為0-9數字 prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) print(prediction_result[0:10]) #查看預測結果中的前10項
然而不夠直觀,不能與數據集的圖片相對應。
6.1定義可視化函數
1 def plot_images_labels_prediction(images,labels,prediction,index,num=10): #圖像列表,標簽列表,預測值列表,從第index個開始顯示 , 缺省一次顯示10幅 2 fig = plt.gcf() #獲取當前圖表,get current figure 3 fig.set_size_inches(10,12) #1英寸等於2.54cm 4 if num > 25: 5 num = 25 #最多顯示25個子圖 6 for i in range(0, num): 7 ax = plt.subplot(5,5,i+1) #獲取當前要處理的子圖 8 ax.imshow(np.reshape(images[index], (28, 28)),cmap='binary') # 顯示第index個圖像 9 title = "label=" + str(np.argmax(labels[index]))# 構建該圖上要顯示的 10 if len(prediction)>0: 11 title += ",predict="+ str(prediction[index]) 12 ax.set_title(title, fontsize=10) #顯示圖上title信息 13 ax.set_xticks([]) #不顯示坐標軸 14 ax.set_yticks([]) 15 index += 1 16 plt.show()
Note:第10行代碼意思是如果參數預測值列表為空,也是可以的,這樣可以直接用該函數查看訓練值
6.2可視化預測結果
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,0,15) #0代表下標從0幅開始,15表示最多顯示15幅
運行結果為:
通過可視化,可以更直觀的查看哪些預測對了 哪些預測錯了。
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,[],0,15)
如果參數預測列表設為空,也能執行,只不過不顯示預測標簽的值。這時函數相當於查看訓練集,運行結果為:
完整代碼如下

#Created by:Huang #Time:2019/10/9 0009. import tensorflow as tf import tensorflow.examples.tutorials.mnist.input_data as input_data import matplotlib.pyplot as plt import numpy as np mnist = input_data.read_data_sets("MNIST_data/",one_hot=True) # print('訓練集train 數量:',mnist.train.num_examples,',驗證集 validation 數量:',mnist.validation.num_examples,',測試集test 數量:',mnist.test.num_examples) # print('train images shape:',mnist.train.images.shape,'labels shaple:',mnist.train.labels.shape) #(55000,784)意思是5w5千條,每一條784位(圖片大小28*28);(55000,10):10分類 One Hot編碼 # def plot_image(image): # plt.imshow(image.reshape(28,28),cmap='binary') # plt.show() # plot_image(mnist.train.images[20000]) x = tf.placeholder(tf.float32,[None,784],name="X")#mnist 中每張圖片共有28*28=784個像素點 y = tf.placeholder(tf.float32,[None,10],name="Y")#0-9一共10個數字=>10個類別 W= tf.Variable(tf.random_normal([784,10]),name="W") #定義變量 b= tf.Variable(tf.zeros([10]),name="b") #用單個神經元構建神經網絡 forward = tf.matmul(x, W) + b #前向計算 pred = tf.nn.softmax(forward) #Softmax分類 train_epochs = 50 #訓練輪數 batch_size=100 #單次訓練樣本數(批次大小) total_batch= int(mnist.train.num_examples/batch_size)#一輪訓練有多少批次 display_step=1 #顯示粒度 learning_rate=0.01 #學習率 loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) #定義交叉熵損失函數 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) #梯度下降優化器 #檢查預測類別tf.argmax(pred,1)與實際類別tf.argmax(y.1)的匹配情況,argmax能把最大值的下標取出來 correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) #准確率,將布爾值轉化為浮點數,並計算平均值 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) sess = tf.Session()#聲明會活 init = tf.global_variables_initializer()#變量初始化 sess.run(init) #開始訓練 for epoch in range(train_epochs): for batch in range(total_batch): xs,ys = mnist.train.next_batch(batch_size) # 讀取批次數據 sess.run(optimizer,feed_dict={x:xs,y:ys}) # 執行批次訓練 #total_batch個批次訓練完成后,使用驗證數據計算課差與准確率;驗證集沒有分批 loss,acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels}) #打印訓練過程中的詳細信息 if (epoch+1) % display_step == 0: print("Train Epoch:",'%02d' %(epoch+1),"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc)) print("Train Finished!") accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print("Test Accuracy:",accu_test) accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels}) print("Train Accuracy:",accu_train) #模型預測 #由於pred預測結果是one-hot編碼格式,所以需要轉換為0-9數字 prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) print(prediction_result[0:10]) #查看預測結果中的前10項 def plot_images_labels_prediction(images,labels,prediction,index,num=10): #圖像列表,標簽列表,預測值列表,從第index個開始顯示 , 缺省一次顯示10幅 fig = plt.gcf() #獲取當前圖表,get current figure fig.set_size_inches(10,12) #1英寸等於2.54cm if num > 25: num = 25 #最多顯示25個子圖 for i in range(0, num): ax = plt.subplot(5,5,i+1) #獲取當前要處理的子圖 ax.imshow(np.reshape(images[index], (28, 28)),cmap='binary') # 顯示第index個圖像 title = "label=" + str(np.argmax(labels[index]))# 構建該圖上要顯示的 if len(prediction)>0: title += ",predict="+ str(prediction[index]) ax.set_title(title, fontsize=10) #顯示圖上title信息 ax.set_xticks([]) #不顯示坐標軸 ax.set_yticks([]) index += 1 plt.show() plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,0,15) #最多顯示25張 # plot_images_labels_prediction(mnist.test.images,mnist.test.labels,[],0,15) #最多顯示25張