MNIST數據集手寫體識別(CNN實現)


github博客傳送門
csdn博客傳送門

本章所需知識:

  1. 沒有基礎的請觀看深度學習系列視頻
  2. tensorflow
  3. Python基礎

資料下載鏈接:

深度學習基礎網絡模型(mnist手寫體識別數據集)

MNIST數據集手寫體識別(CNN實現)

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data  # 導入下載數據集手寫體
mnist = input_data.read_data_sets('../MNIST_data/', one_hot=True)


class CNNNet:  # 創建一個CNNNet類
    def __init__(self):
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='input_x')  # 創建數據占位符
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='input_y')  # 創建標簽占位符

        self.w1 = tf.Variable(tf.truncated_normal(shape=[3, 3, 1, 16], dtype=tf.float32, stddev=tf.sqrt(1 / 16), name='w1'))  # 定義 第一層/輸入層/卷積層 w
        self.b1 = tf.Variable(tf.zeros(shape=[16], dtype=tf.float32, name='b1'))  # 定義 第一層/輸入層/卷積層 偏值b

        self.w2 = tf.Variable(tf.truncated_normal(shape=[3, 3, 16, 32], dtype=tf.float32, stddev=tf.sqrt(1 / 32), name='w2'))  # 定義 第二層/卷積層 w
        self.b2 = tf.Variable(tf.zeros(shape=[32], dtype=tf.float32, name='b2'))  # 定義 第二層/卷積層 偏值b

        self.fc_w1 = tf.Variable(tf.truncated_normal(shape=[28 * 28 * 32, 128], dtype=tf.float32, stddev=tf.sqrt(1 / 128), name='fc_w1'))  # 定義 第三層/全鏈接層/ w
        self.fc_b1 = tf.Variable(tf.zeros(shape=[128], dtype=tf.float32, name='fc_b1'))  # 定義 第三層/全鏈接層/ 偏值b

        self.fc_w2 = tf.Variable(tf.truncated_normal(shape=[128, 10], dtype=tf.float32, stddev=tf.sqrt(1 / 10), name='fc_w2'))  # 定義 第四層/全鏈接層/輸出層 w
        self.fc_b2 = tf.Variable(tf.zeros(shape=[10], dtype=tf.float32, name='fc_b2'))  # 定義 第四層/全鏈接層/輸出層 偏值b

	# 前向計算
    def forward(self):
		# 前向計算 第一層/輸入層/卷積層
        self.conv1 = tf.nn.relu(tf.nn.conv2d(self.x, self.w1, strides=[1, 1, 1, 1], padding='SAME', name='conv1') + self.b1)
        # 前向計算 第二層/卷積層
		self.conv2 = tf.nn.relu(tf.nn.conv2d(self.conv1, self.w2, strides=[1, 1, 1, 1], padding='SAME', name='conv2') + self.b2)
        # 將第二層卷積后的數據撐開為 [批次, 數據]
		self.flat = tf.reshape(self.conv2, [-1, 28 * 28 * 32])
		# 前向計算 第三層/全鏈接層
        self.fc1 = tf.nn.relu(tf.matmul(self.flat, self.fc_w1) + self.fc_b1)
        # 前向計算 第四層/全鏈接層/輸出層
		self.fc2 = tf.matmul(self.fc1, self.fc_w2) + self.fc_b2
        # 輸出層 softmax分類
		self.output = tf.nn.softmax(self.fc2)

	# 后向計算
    def backward(self):
		# 定義 softmax交叉熵 求損失
        self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.fc2, labels=self.y))
        # 使用 AdamOptimizer優化器 優化cost 損失函數
		self.opt = tf.train.AdamOptimizer().minimize(self.cost)

	# 計算測試集識別精度
    def acc(self):
		# 將預測值 output 和 標簽值 self.y 進行比較
        self.acc2 = tf.equal(tf.argmax(self.output, 1), tf.argmax(self.y, 1))
		#  最后對比較出來的bool值 轉換為float32類型后 求均值就可以看到滿值為 1的精度顯示
        self.accaracy = tf.reduce_mean(tf.cast(self.acc2, dtype=tf.float32))


if __name__ == '__main__':
    net = CNNNet()  # 啟動tensorflow繪圖的CNNNet
    net.forward()   # 啟動前向計算
    net.backward()  # 啟動后向計算
    net.acc()       # 啟動精度計算
    init = tf.global_variables_initializer()  # 定義初始化tensorflow所有變量操作
    with tf.Session() as sess:  # 創建一個Session會話
        sess.run(init)          # 執行init變量內的初始化所有變量的操作
        for i in range(10000):  # 訓練10000次
            ax, ay = mnist.train.next_batch(100)  # 從mnist數據集中取數據出來 ax接收圖片 ay接收標簽
            ax_batch = ax.reshape([-1, 28, 28, 1])  # 將取出的 圖片數據 reshape成 NHWC 結構
            loss, output, accaracy, _ = sess.run(fetches=[net.cost, net.output, net.accaracy, net.opt], feed_dict={net.x: ax_batch, net.y: ay})  # 將數據喂進CNN網絡
            # print(loss)      # 打印損失
            # print(accaracy)  # 打印訓練精度
            if i % 10 == 0:    # 每訓練10次
                test_ax, test_ay = mnist.test.next_batch(100)  # 則使用測試集對當前網絡進行測試
                test_ax_batch = test_ax.reshape([-1, 28, 28, 1])  # 將取出的 圖片數據 reshape成 NHWC 結構
                test_output = sess.run(net.output, feed_dict={net.x: test_ax_batch})  # 將測試數據喂進網絡 接收一個output值
                test_acc = tf.equal(tf.argmax(test_output, 1), tf.argmax(test_ay, 1))  # 對output值和標簽y值進行求比較運算
                accaracy2 = sess.run(tf.reduce_mean(tf.cast(test_acc, dtype=tf.float32)))  # 求出精度的准確率進行打印
                print(accaracy2)  # 打印當前測試集的精度

最后附上訓練截圖:

CNN


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM