MNIST 手寫數字識別【入門】


1 問題描述

MNIST 數據集來自美國國家標准與技術研究所, National Institute of Standards and Technology (NIST).數據集由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口 普查局 (the Census Bureau) 的工作人員

2 數據集獲取

2.1   網站獲取: http://yann.lecun.com/exdb/mnist/ 

2.2 TensorFlow提供了數據集讀取方法

#導入TensorFlow
import tensorflow as tf
#導入讀取方法
import tensorflow.examples.tutorials.mnist.input_data as input_data
#讀入數據集
mnist = input_data.read_data_sets("MNIST_data/",one_hot = True)

注:MNIST數據集文件在讀取時如果指定目錄下不存在,則會自動去下載,需等待一定時間如果已經存在了,則直接讀取

3 了解數據集

print("訓練集 train 數量:",mnist.train.num_examples,
      ",驗證集 validation 數量:",mnist.validation.num_examples,
      ",測試集 test 數量:",mnist.test.num_examples
     )
print("train images shape:",mnist.train.images.shape,
     "lables shaple:",mnist.train.labels.shape)

3.1 看具體image的數據

print(len(mnist.train.images[0]), mnist.train.images[0].shape)
mnist.train.images[0]
# image數據再塑性reshape
mnist.train.images[0].reshape(28,28)
#可視化image
#定義函數
import matplotlib.pyplot as plt

def plot_image(image):
    plt.imshow(image.reshape(28,28),cmap = "binary")
    plt.show()
plot_image(mnist.train.images[3424])

3.2 reshape() 函數  

import numpy as np

int_array = np.array([i for i in range(64)])
print(int_array)
int_array.reshape(4,16)
plt.imshow(mnist.train.images[454].reshape(14,56),cmap = "binary")
plt.show()

3.3數據的批量讀取

3.3.1 python切片

print(mnist.train.labels[0:10])

3.3.2 函數讀取

# next_batch () 實現內部會對數據集先做shuffle處理
batch_images_xs,batch_labels_ys = mnist.train.next_batch(batch_size=10)
print(batch_labels_ys)
 

4 標簽數據和獨熱編碼

4.1標簽數據  

#打印image
plot_image(mnist.train.images[1])
# 打印imag對應的標簽
print(mnist.train.labels[1])

4.2 獨熱編碼

 4.2.1 為什么要采用 one hot 編碼

4.2.2如何從獨熱編碼取值? 

import numpy as np
# 打印imag對應的標簽
print(mnist.train.labels[1])
# argmax 返回的是最大數的索引
np.argmax(mnist.train.labels[1])

 

4.3 非one-hot編碼的標簽值

mnist_no_one_hot = input_data.read_data_sets("MNIST_data/",one_hot=False)
print(mnist_no_one_hot.train.labels[0:10])

5 數據集的划分

5.1 第一種划分

訓練集 - 用於訓練模型的子集集

測試集 - 用於測試模型的子集

確保測試集滿足以下兩個條件:
  規模足夠大,可產生具有統計意義的結果
  能代表整個數據集,測試集的特征應該與訓練集的特征相同

5.1.1工作流程

5.1.2 存在的問題

多次重復執行該流程可能導致模型不知不覺地擬合了特定測試集的特性 

5.2 第二種划分

訓練集 - 用於訓練模型的子集集

驗證集 - 用於驗證模型的子集

測試集 - 用於測試模型的子集

5.2.1工作流程

5.3 數據驗證 

#讀取驗證數據
print("驗證圖像:",mnist.validation.images.shape,
     "標簽:",mnist.validation.labels.shape)
#讀取測試數據
print("測試圖像:",mnist.test.images.shape,
     "標簽:",mnist.test.labels.shape)
#讀取訓練數據
print("訓練圖像:",mnist.train.images.shape,
     "標簽:",mnist.train.labels.shape)

6 模型構建

6.1 定義待輸入數據的占位符  

#mnist 中每張圖片共有28*28 = 784個像素點
x = tf.placeholder(tf.float32,[None,784],name="X")
y = tf.placeholder(tf.float32,[None,10],name="Y")

6.2 定義模型變量

以正態分布的隨機數初始化權重W,以常數0初始化偏置b 

#定義變量
W = tf.Variable(tf.random_normal([784,10],name="W"))
b = tf.Variable(tf.zeros([10]),name="b")

6.3 了解 tf.random_normal ()

#生成100個隨機數
norm = tf.random_normal([100])
with tf.Session() as sess:
    norm_data = norm.eval()
#打印前10個隨機數
print(norm_data[:10])

#圖形化打印出來
import matplotlib.pyplot as plt
plt.hist(norm_data)
plt.show()

6.4 定義前向計算

# matmul叉乘,前向計算
forward = tf.matmul(x,W) + b 

6.4.1 結果分類

# Softmax 分類
pred = tf.nn.softmax(forward)

7 邏輯回歸

許多問題的預測結果是一個在連續空間的數值,比如房價預測問題,可以用線性模型來描述:

但也有很多場景需要輸出的是概率估算值,例如:
  • 根據郵件內容判斷是垃圾郵件的可能性
  • 根據醫學影像判斷腫瘤是惡性的可能性
  • 手寫數字分別是 0、1、2、3、4、5、6、7、8、9的可能性(概率)
這時需要將預測輸出值控制在 [0,1]區間內
二元分類問題的目標是正確預測兩個可能的標簽中的一個【結果只有一個】
邏輯回歸(Logistic Regression)可以用於處理這類問題

7.1 Sigmod 函數

邏輯回歸模型如何確保輸出值始終落在 0 和 1 之間。
Sigmod函數(S型函數)生成的輸出值正好具有這些特性,其定義如下:

定義域為全體實數,值域在[0,1]之間

Z值在0點對應的結果為0.5
sigmoid函數連續可微分

7.1.1 特定樣本的邏輯回歸模型的輸出

7.2 邏輯回歸中的損失函數

線性回歸的損失函數是平方損失,如果邏輯回歸的損失函數也為平方損失,則:

 

其中:

將Sigmoid函數帶入上述函數

非凸函數,有多個極小值

如果采用梯度下降法,會容易導致陷入局部最優解中

 

7.2.1 二元邏輯回歸的損失函數采用 “對數損失函數” :

其中:

 

8 多元分類

8.1 Softmax 思想

邏輯回歸可生成介於0-1.0之間的小數

Softmax將這一想法延伸到多類別領域

在多累別問題中,Softmax會為每一個分類分配一個用小數表示的概率,這些小數表示的概率相加之和為1.0

8.2 Softmax實例

8.3神經網絡中的Softmax層

 

8.4 Softmax方程式

8.5 Softmax舉例

8.6 交叉熵損失函數

交叉熵是信息論中的概念,原為估算平均編碼長度的。給定兩個概率分布p和q,通過q來表示p的交叉熵

交叉熵表示的是兩個概率分布之間的距離,p表示正確答案,q表示預測值,交叉熵越小,兩個概率分布越接近

8.7 交叉熵損失函數計算實例

8.8 定義交叉熵損失函數

# 定義損失函數
loss_function = tf.reduce_mean( -tf.reduce_sum (y * tf.log(pred),reduction_indices = 1))

8.9 argmax()用法

#載入數據
import tensorflow as tf
import numpy as np

arr1 = np.array([1,3,2,5,7,0])
arr2 = np.array([[1.0,2,3],[3,2,1],[4,7,2],[8,3,2]])
print(arr1)
print(arr2)

argmax_1 = tf.argmax(arr1)
argmax_20 = tf.argmax(arr2,0) #指定參數為0,按第一維(行)的元素取值,即同列的每一行
argmax_21 = tf.argmax(arr2,1) #指定參數為1,按第二維(列)的元素取值,即同行的每一列
argmax_22 = tf.argmax(arr2,-1)#指定參數為-1,即第最后維的元素取值

with tf.Session() as sess:
    print(argmax_1.eval())
    print(argmax_20.eval())
    print(argmax_21.eval())
    print(argmax_22.eval())

9 分類模型構建與訓練實踐

9.1載入數據

#載入數據
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot = True)

9.2 定義占位符

# mnist 中每張圖片共有28*28=784個像素點
x = tf.placeholder(tf.float32,[None,784],name="X")
# 0 -9 一共10個數字 ,10個類別
y = tf.placeholder(tf.float32,[None,10],name = "Y")

9.3變量定義

#定義變量
W = tf.Variable(tf.random_normal([784,10],name = "W"))
b = tf.Variable(tf.zeros([10]),name="b")  

神經網絡中,權值W的初始值設為正態分布的隨機數,偏置項b的初始值為1 -10的隨機數或常數。

9.4單個神經元構建神經網絡  

#向前計算
forward = tf.matmul(x,W) + b

9.5 softmax 分類

#softmax分類
pred = tf.nn.softmax(forward)
Softmax Regression 會對每一類別估計出一個概率
工作原理:判定為某一類的特征相加,然后將這些特征轉化為判定是這類的概率

9.6 設置訓練參數

train_epochs = 50   #訓練輪數
batch_size = 100  #單次訓練樣本數【批次大小】
total_batch = int(mnist.train.num_examples/batch_size)  #一輪訓練有多少批次
display_step = 1  #顯示粒度
learning_rate = 0.01 #學習率 

9.7 定義損失函數

# 定義損失函數
loss_function = tf.reduce_mean( -tf.reduce_sum (y * tf.log(pred),reduction_indices = 1))

9.8 選擇優化器

# 選擇優化器【梯度下降】
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) 

9.9 定義准確率

#檢查預測類別tf.argmax(pred,1)與實際類別tf.argmax(y,1)的匹配情況
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
# 准確率,將布爾值轉化為浮點數,計算平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

9.10 會話聲明

#聲明會話
sess = tf.Session()
#變量初始化
init = tf.global_variables_initializer()
sess.run(init)

9.11 模型訓練

#開始訓練
for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size) #讀取批次數據
        sess.run(optimizer,feed_dict={x:xs,y:ys})  #執行批次訓練
        
    #total_batch 個批次訓練后,使用驗證數據計算誤差與准確性,驗證集沒有分批
    loss,acc = sess.run([loss_function,accuracy],
        feed_dict = {x: mnist.validation.images ,y: mnist.validation.labels})
    
    #打印訓練過程中的詳細信息
    if(epoch+1) % display_step == 0:
        print("Train Epoch:",'%02d' % (epoch+1),"Loss=", "{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
        
print("Train over!")        

結果:損失值Loss是趨於更小的,准確率Accuracy 越來 越高

9.12 測試模型

測試集中評估准確率

accu_test = sess.run(accuracy,feed_dict = {x:mnist.test.images,y:mnist.test.labels})
print("Test Accuracy:",accu_test)

驗證集中評估准確率

accu_validation = sess.run(accuracy,feed_dict = {x:mnist.validation.images,y:mnist.validation.labels})
print("Test Accuracy:",accu_validation)

訓練集中評估准確率 

accu_train = sess.run(accuracy,feed_dict = {x:mnist.train.images,y:mnist.train.labels})
print("Test Accuracy:",accu_train)

10 模型訓練和可視化

10.1 進行預測

 在建立模型並進行訓練后,若認為准確率可以接受,則使用此模型進行預測

# 由於pred 預測結果是one_hot 編碼格式,所以需要轉換成0 — 9數字
prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})

10.2 查看預測結果

#查看預測結果中的前10項
prediction_result[0:15]

10.3 定義可視化函數

#定義可視化函數
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images, #圖像列表
                                  labels, #標簽列表
                                  prediction, #預測值列表
                                  index, #從第index個開始顯示 
                                  num = 10 ): #缺省一次顯示10幅
    fig = plt.gcf() #獲取當前圖表,
    fig.set_size_inches(10,12) # 一英寸為2.54cm
    if num > 25:
        num = 25   #最多顯示25個子圖
    for i in range (0,num):
        ax = plt.subplot(5,5,i+1)  #獲取當前要處理的子圖
        
        ax.imshow(np.reshape(images[index],(28,28)), #顯示第index個圖像
                 cmap = "binary")
        title = "label=" + str(np.argmax(labels[index])) #構建該圖上要顯示的
        if len(prediction) > 0:
            title += ",predict=" + str(prediction[index])
        ax.set_title(title,fontsize = 10) #顯示圖上的title信息
        ax.set_xticks([]); #不顯示坐標軸
        ax.set_yticks([])
        index += 1
    plt.show()       

10.4可視化顯示 

plot_images_labels_prediction(mnist.test.images,
                             mnist.test.labels,
                             prediction_result,0,10)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM