Tensorflow之MNIST手寫數字識別:分類問題(2)


整體代碼:

#數據讀取
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

#定義待輸入數據的占位符
#mnist中每張照片共有28*28=784個像素點
x = tf.placeholder(tf.float32,[None,784],name="X")

#0-9一共10個數字=>10個類別
y = tf.placeholder(tf.float32,[None,10],name="Y")

#定義模型變量
#以正態分布的隨機數初始化權重W,以常數0初始化偏置b
#在神經網絡中,權值W的初始值通常設為正態分布的隨機數,偏置項b的初始值通常也設置為正態分布的隨機數或常數。
W = tf.Variable(tf.random_normal([784,10],name="W"))
b = tf.Variable(tf.zeros([10]),name="b")

#用單個神經元構建神經網絡
forward=tf.matmul(x,W) + b   #前向計算

#結果分類
#當我們處理多分類任務的時候,通常需要使用Softmax Regression模型。Softmax會對每一類別估算出一個概率。
#工作原理:將判定為某一類的特征相加,然后將這些特征轉化為判定是這一類的概率
pred = tf.nn.softmax(forward)     #Softmax分類

#設置訓練參數
train_epochs = 120     #訓練輪數
batch_size = 120      #單次訓練樣本數(批次大小)
total_batch = int(mnist.train.num_examples/batch_size)              #一輪訓練有多少批次
display_step = 1   #顯示粒度
learning_rate = 0.01             #學習率 

#概率估算值需要將預測輸出值控制在[0,1]區間內。二元分類問題的目標是正確預測兩個可能標簽中的一個
#邏輯回歸可以用於處理這類問題。二元邏輯回歸的損失函數一般采用對數損失函數
#多元分類:邏輯回歸可生成介於0到1.0之間的小數。Softmax將這一想法延伸到多類別領域。
#在多類別問題中,Softmax會為每個類別分配一個用小數表示的概率。這些用小數表示的概率相加之和必須是1.0

#交叉熵損失函數:交叉熵是一個信息論的概念,它原來是用來估算平均編碼長度的。
#交叉熵刻畫的是兩個概率分布之間的距離,p代表正確答案,q代表的預測值,交叉熵越小,兩個概率的分布越接近
#定義損失函數
loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))    #交叉熵

#選擇優化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)     #梯度下降優化器

#定義准確率
# 檢查預測類別tf.argmax(pred,1)與實際類別tf.argmax(y,1)的匹配情況
#argmax()將數組中最大值的下標取出來
correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))

#准確率,將布爾值轉化為浮點數,並計算平均值    tf.cast()將布爾值投射成浮點數
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#聲明會話,初始化變量
sess = tf.Session()
init = tf.global_variables_initializer()   #變量初始化
sess.run(init)

#訓練模型
for epoch in range(train_epochs):
    for batch in range(total_batch):
        xs,ys = mnist.train.next_batch(batch_size)  #讀取批次數據
        sess.run(optimizer,feed_dict={x:xs,y:ys})   #執行批次訓練
        
    #total_batch個批次訓練完成后,使用驗證數據計算誤差與准確率,驗證集沒有分批
    loss,acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
    
    #打印訓練過程中的詳細信息
    if (epoch+1) % display_step == 0:
        print("Train Epoch:",'%02d'%(epoch+1),"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc))
        
print("Train Finished!")
        
#評估模型
#完成訓練后,在測試集上評估模型的准確率
accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Test Accuracy:",accu_test)
#完成訓練后,在驗證集上評估模型的准確率
accu_validation = sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
print("Test Accuracy:",accu_validation)
#完成訓練后,在訓練集上評估模型的准確率
accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
print("Test Accuracy:",accu_train)

#應用模型
#在建立模型並進行訓練后,若認為准確率可以接受,則可以使用此模型進行預測
#由於pred預測結果是one_hot編碼格式,所以需要轉換成0~9數字
prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})

#查看預測結果中的前10項
prediction_result[0:10]

#定義可視化函數
def plot_images_labels_prediction(images,labels,prediction,index,num=10):  #參數: 圖形列表,標簽列表,預測值列表,從第index個開始顯示,缺省一次顯示10幅
    fig = plt.gcf()             #獲取當前圖表,Get Current Figure
    fig.set_size_inches(10,12)    #1英寸等於2.45cm
    if num > 25 :      #最多顯示25個子圖
        num = 25
    for i in range(0,num):
        ax = plt.subplot(5,5,i+1)   #獲取當前要處理的子圖
        ax.imshow(np.reshape(images[index],(28,28)), cmap = 'binary')              #顯示第index個圖像
        title = "labels="+str(np.argmax(labels[index]))              #構建該圖上要顯示的title信息
        if len(prediction)>0:
            title += ",predict="+str(prediction[index])
            
        ax.set_title(title,fontsize=10)    #顯示圖上的title信息
        ax.set_xticks([])           #不顯示坐標軸
        ax.set_yticks([])
        index += 1
    plt.show()
#可視化預測結果
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,10)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM