MNIST手寫字母識別(一)


這是個分類應用入門:使用softmax分類,簡單來說就是概論轉化為0-1區間的一個數字

讀取數據集

1 # 導入相關庫
2 import tensorflow as tf
3 from tensorflow.examples.tutorials.mnist import input_data
4 mnist=input_data.read_data_sets("D:/MNIST",one_hot=True)

 

 獨熱編碼(one hot encoding)

一種稀疏向量,其中:一個元素設為1,所有其他元素均設為0

獨熱編碼常用於表示擁有有限個可能值的字符串或標識符

工作流程:

  1 import tensorflow as tf
  2 import matplotlib.pyplot as plt
  3 import numpy as np
  4 import tensorflow.examples.tutorials.mnist.input_data as input_data  
  5 mnist=input_data.read_data_sets("MNIST_data",one_hot=True)  #讀取數據
  6 import os  #可加可不加,屏蔽通知消息
  7 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  8 
  9 print('訓練集 train 數量:',mnist.train.num_examples,
 10       ',驗證集 validation 數量:',mnist.validation.num_examples,
 11       ',測試集 test 數量:',mnist.test.num_examples)
 12 
 13 # print('train images shape:',mnist.train.images.shape,
 14 #       'labels shaple:',mnist.train.labels.shape)
 15 
 16 # print((mnist.train.images[0].reshape(28,28)))
 17 # print(len(mnist.train.images[0].shape))
 18 
 19 # def plot_image(image):
 20 #     plt.imshow(image.reshape(28,28))
 21 #     plt.show()
 22 #
 23 # plot_image(mnist.train.images[1])
 24 # plt.imshow(mnist.train.images[20000].reshape(14,56))
 25 # plt.show()
 26 
 27 # print(mnist.train.labels[1])
 28 # print(np.argmax(mnist.train.labels[1]))
 29 # mnist_no_one_hot=input_data.read_data_sets("MNIST_data",one_hot=False)
 30 # print(mnist_no_one_hot.train.labels[0:10])
 31 #
 32 # print('validation images:',mnist.validation.images.shape,'labels:',mnist.validation.labels.shape)
 33 #
 34 # print('test images:',mnist.test.images.shape,'labels:',mnist.test.labels.shape)
 35 #
 36 # batch_images_xs,batch_labels_ys=mnist.train.next_batch(batch_size=10)
 37 # print(mnist.train.labels[0:10])
 38 # # print(batch_labels_ys)
 39 
 40 # mnist中每張圖片共有28*28=784個像素點
 41 x=tf.placeholder(tf.float32,[None,784])
 42 # 0-9一共10個數字->10個類別
 43 y=tf.placeholder(tf.float32,[None,10])
 44 
 45 # 定義模型變量(以正態分布的隨機數初始化權重W,以常數0初始化偏置b)
 46 W=tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0))
 47 b=tf.Variable(tf.zeros([10]))
 48 
 49 # 前向計算
 50 forward=tf.matmul(x,W)+b
 51 #softmax分類
 52 pred=tf.nn.softmax(forward)
 53 # 定義交叉熵損失函數
 54 loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))
 55 
 56 # 設置訓練參數
 57 train_epochs=150 # 訓練輪數
 58 batch_size=50 # 單次訓練樣本數(批次大小)
 59 total_batch=int(mnist.train.num_examples/batch_size) # 一輪訓練的批次數
 60 display_step=1 # 顯示粒度
 61 learning_rate=0.04 # 學習率
 62 
 63 # 分類模型構建與訓練實踐
 64 #選擇優化器,梯度下降
 65 optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
 66 
 67 # 定義准確率,檢查預測類別tf.argmax(pred,1)與實際類別tf.argmax(y,1)的匹配情況
 68 correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
 69 # 准確率,將布爾值轉化為浮點數,並計算平均值
 70 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
 71 
 72 # 聲明會話
 73 sess=tf.Session()
 74 init=tf.global_variables_initializer()
 75 sess.run(init)
 76 
 77 # 訓練模型
 78 for epoch in range(train_epochs):
 79       for batch in range(total_batch):
 80             xs, ys = mnist.train.next_batch(batch_size)  # 讀取批次數據
 81             sess.run(optimizer, feed_dict={x: xs, y: ys})  # 執行批次訓練
 82 
 83       # total_batch批次訓練完成之后,使用驗證數據計算誤差與准確率,驗證集沒有分批。
 84       loss, acc = sess.run([loss_function, accuracy],feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
 85 
 86       # 打印訓練過程中的詳細信息
 87       if (epoch + 1)% display_step==0:
 88             print("train_epoch:", '%02d' % (epoch + 1), "loss=", "{:.9f}".format(loss),"accuracy=", '{:.4f}'.format(acc))
 89 print("train finished!")
 90 
 91 # 在測試集上評估模型准確率
 92 accu_test=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
 93 print("test accuracy:",accu_test)
 94 
 95 # 在驗證集上評估模型准確率
 96 accu_validation=sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
 97 print("validatin accuracy:",accu_validation)
 98 
 99 # 在訓練集上評估模型准確率
100 accu_train=sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
101 print("tarin accuracy:",accu_train)
102 
103 # 由於pred預測結果是one-hot編碼格式,所以需要轉化為0~9數字
104 prediction_result=sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})
105 
106 # 查看結果中的前十項
107 prediction_result[0:10]
108 # 定義可視化函數
109 
110 def plt_images_labels_prediction(images,  # 圖像列表
111                                   labels,  # 標簽列表
112                                   prediction,  # 預測值列表
113                                   index,  # 從第index個開始顯示
114                                   num=10):  # 缺省依次顯示10副
115       fig = plt.gcf()  # 獲取當前圖表,get current figure
116       fig.set_size_inches(10, 12)  # 1英寸等於2.54cm
117       if num > 25:
118             num = 25  # 最多顯示25個子圖
119       for i in range(0, num):
120             ax = plt.subplot(5, 5, i + 1)  # 獲取當前要處理的子圖
121 
122             ax.imshow(np.reshape(images[index], (28, 28)),
123                       cmap='binary')  # 顯示第index個圖像
124             title = "labels=" + str(np.argmax(labels[index]))  # 構建該圖上要顯示的title信息
125             if len(prediction) > 0:
126                   title += ",predict=" + str(prediction[index])
127 
128             ax.set_title(title)  # 顯示圖上的title
129             ax.set_xticks([])  # 不顯示坐標軸
130             ax.set_yticks([])
131             index += 1
132       plt.show()
133 # 可視化預測結果
134 plt_images_labels_prediction(mnist.test.images,
135                              mnist.test.labels,
136                              prediction_result,10,25)

 

大概就結束了,相當於機器學習的一個helloworld

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM