LeNet-5 for MNIST Classification



My knowledge is limited; if you find any mistakes, please point them out!

1. LeNet-5

1.1 Introduction

LeNet-5 is a neural network architecture proposed by Yann LeCun, Turing Award winner and one of the "three giants of deep learning", in the paper "Gradient-Based Learning Applied to Document Recognition" (download: https://www.researchgate.net/publication/2985446_Gradient-Based_Learning_Applied_to_Document_Recognition ). It is highly effective on handwritten digits and machine-printed characters.

1.2 Network Structure

The figure above shows the network structure from the original paper. Since MNIST images are 28 * 28 (single channel), the structure needs slight adjustments.

The network structure used for this task is listed below; a shape-check sketch follows the list:

Input image: 28 * 28 * 1

Convolutional layer: six 3 * 3 * 1 filters, stride 1, same padding; output 28 * 28 * 6

Max-pooling layer: 2 * 2 filter, stride 2; output 14 * 14 * 6

Convolutional layer: sixteen 3 * 3 * 6 filters, stride 1, valid padding; output 12 * 12 * 16

Max-pooling layer: 2 * 2 filter, stride 2; output 6 * 6 * 16

Fully connected layer: 120 units

Fully connected layer: 84 units

Output layer: 10 units
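
To double-check these output sizes, here is a minimal shape-check sketch (not from the original post) that pushes the same layer stack through Keras shape inference:

import tensorflow as tf

# Shape check only: mirror the adapted LeNet-5 layer stack and print each feature-map shape
x0 = tf.keras.Input(shape=(28, 28, 1))
x1 = tf.keras.layers.Conv2D(6, 3, strides=1, padding='same')(x0)    # (None, 28, 28, 6)
x2 = tf.keras.layers.MaxPooling2D(2, strides=2)(x1)                 # (None, 14, 14, 6)
x3 = tf.keras.layers.Conv2D(16, 3, strides=1, padding='valid')(x2)  # (None, 12, 12, 16)
x4 = tf.keras.layers.MaxPooling2D(2, strides=2)(x3)                 # (None, 6, 6, 16)
print(x1.shape, x2.shape, x3.shape, x4.shape)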

2. Implementing LeNet-5 with TensorFlow 2

2.1 Data Preprocessing

First, load the dataset (it is recommended to download it once and save it locally instead of fetching it online repeatedly) and scale the grayscale values to the range 0 to 1 to make training easier. Also note that train_data must be reshaped from [60000, 28, 28] to [60000, 28, 28, 1] to prepare for the convolution operations later.

import numpy as np
import tensorflow as tf

# Load MNIST, scale pixel values to [0, 1], and add a channel axis for the conv layers
(train_data, train_label), (test_data, test_label) = tf.keras.datasets.mnist.load_data()
train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)
train_label = train_label.astype(np.int32)
test_data = np.expand_dims(test_data.astype(np.float32) / 255.0, axis=-1)
test_label = test_label.astype(np.int32)
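
A quick, optional shape check on the arrays above, to confirm the extra channel axis was added:

print(train_data.shape, train_label.shape)  # (60000, 28, 28, 1) (60000,)
print(test_data.shape, test_label.shape)    # (10000, 28, 28, 1) (10000,)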

2.2 Building the Network

Build the network according to the modified LeNet-5 structure. The model is defined by subclassing tf.keras.Model, and BatchNormalization layers are added.

class LeNet5(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # Adapted LeNet-5: two conv/pool stages followed by three dense layers
        self.conv1 = tf.keras.layers.Conv2D(filters=6, kernel_size=[3, 3], strides=1, padding='same')
        self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=[2, 2], strides=2)
        self.conv2 = tf.keras.layers.Conv2D(filters=16, kernel_size=[3, 3], strides=1, padding='valid')
        self.pool2 = tf.keras.layers.MaxPooling2D(pool_size=[2, 2], strides=2)
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=120, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=84, activation=tf.nn.relu)
        self.dense3 = tf.keras.layers.Dense(units=10, activation=tf.nn.softmax)
        # Batch normalization after each convolution
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.bn2 = tf.keras.layers.BatchNormalization()

    def call(self, inputs):
        # Stage 1: conv -> BN -> pool -> ReLU
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.pool1(x)
        x = tf.nn.relu(x)
        # Stage 2: conv -> BN -> pool -> ReLU
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.pool2(x)
        x = tf.nn.relu(x)
        # Classifier head
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return x
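
As a sanity check (not part of the original post), a throwaway instance of the subclassed model can be run on a dummy batch to confirm it outputs 10 class probabilities per image:

# Hypothetical sanity check: run one fake batch through a throwaway instance
net = LeNet5()
dummy = tf.zeros([4, 28, 28, 1])   # four fake 28x28 grayscale images
print(net(dummy).shape)            # expected: (4, 10)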

2.3 Compiling the Model

The model uses the Adam optimizer with an initial learning rate of 1e-3. Since the labels are integer-encoded, sparse_categorical_crossentropy is used as the loss.

# Instantiate the model defined above, then configure optimizer, loss, and metric
model = LeNet5()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy]
)

2.4 Training the Model

During training, the batch size is 128 and the model is trained for 20 epochs; 6,000 samples are held out as a validation set, and the remaining samples are used for training.

A learning-rate decay scheme is used: if the validation accuracy does not improve for 3 consecutive epochs, the learning rate is multiplied by 0.2. To prevent overfitting, an EarlyStopping mechanism is also used: if the validation accuracy does not improve for 6 consecutive epochs, training stops.

# Multiply the learning rate by 0.2 after 3 epochs without validation-accuracy improvement
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_sparse_categorical_accuracy', factor=0.2, patience=3)
# Stop training after 6 epochs without validation-accuracy improvement
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=6)
history = model.fit(train_data, train_label, epochs=20, batch_size=128, verbose=2,
                    validation_split=0.1, callbacks=[reduce_lr, early_stopping])

2.5 Test Results

After training, the classification accuracy on the MNIST test set exceeds 99%. The training/validation accuracy curves and the run log are shown below; the full code is available at https://github.com/NickHan-cs/Tensorflow2.x.

Epoch 1/20
422/422 - 2s - loss: 0.2597 - sparse_categorical_accuracy: 0.9201 - val_loss: 0.2141 - val_sparse_categorical_accuracy: 0.9300 - lr: 0.0010
Epoch 2/20
422/422 - 2s - loss: 0.0704 - sparse_categorical_accuracy: 0.9779 - val_loss: 0.0550 - val_sparse_categorical_accuracy: 0.9825 - lr: 0.0010
Epoch 3/20
422/422 - 2s - loss: 0.0507 - sparse_categorical_accuracy: 0.9841 - val_loss: 0.0576 - val_sparse_categorical_accuracy: 0.9823 - lr: 0.0010
Epoch 4/20
422/422 - 2s - loss: 0.0410 - sparse_categorical_accuracy: 0.9867 - val_loss: 0.0505 - val_sparse_categorical_accuracy: 0.9838 - lr: 0.0010
Epoch 5/20
422/422 - 2s - loss: 0.0314 - sparse_categorical_accuracy: 0.9897 - val_loss: 0.0513 - val_sparse_categorical_accuracy: 0.9852 - lr: 0.0010
Epoch 6/20
422/422 - 2s - loss: 0.0273 - sparse_categorical_accuracy: 0.9913 - val_loss: 0.0472 - val_sparse_categorical_accuracy: 0.9875 - lr: 0.0010
Epoch 7/20
422/422 - 2s - loss: 0.0269 - sparse_categorical_accuracy: 0.9909 - val_loss: 0.0453 - val_sparse_categorical_accuracy: 0.9872 - lr: 0.0010
Epoch 8/20
422/422 - 2s - loss: 0.0191 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.0465 - val_sparse_categorical_accuracy: 0.9885 - lr: 0.0010
Epoch 9/20
422/422 - 2s - loss: 0.0172 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.0549 - val_sparse_categorical_accuracy: 0.9863 - lr: 0.0010
Epoch 10/20
422/422 - 2s - loss: 0.0157 - sparse_categorical_accuracy: 0.9948 - val_loss: 0.0466 - val_sparse_categorical_accuracy: 0.9882 - lr: 0.0010
Epoch 11/20
422/422 - 2s - loss: 0.0126 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.0616 - val_sparse_categorical_accuracy: 0.9870 - lr: 0.0010
Epoch 12/20
422/422 - 2s - loss: 0.0044 - sparse_categorical_accuracy: 0.9988 - val_loss: 0.0412 - val_sparse_categorical_accuracy: 0.9902 - lr: 2.0000e-04
Epoch 13/20
422/422 - 2s - loss: 0.0027 - sparse_categorical_accuracy: 0.9995 - val_loss: 0.0438 - val_sparse_categorical_accuracy: 0.9895 - lr: 2.0000e-04
Epoch 14/20
422/422 - 2s - loss: 0.0021 - sparse_categorical_accuracy: 0.9997 - val_loss: 0.0441 - val_sparse_categorical_accuracy: 0.9893 - lr: 2.0000e-04
Epoch 15/20
422/422 - 2s - loss: 0.0019 - sparse_categorical_accuracy: 0.9997 - val_loss: 0.0451 - val_sparse_categorical_accuracy: 0.9902 - lr: 2.0000e-04
Epoch 16/20
422/422 - 2s - loss: 0.0013 - sparse_categorical_accuracy: 0.9999 - val_loss: 0.0447 - val_sparse_categorical_accuracy: 0.9893 - lr: 4.0000e-05
Epoch 17/20
422/422 - 2s - loss: 0.0013 - sparse_categorical_accuracy: 0.9999 - val_loss: 0.0445 - val_sparse_categorical_accuracy: 0.9895 - lr: 4.0000e-05
Epoch 18/20
422/422 - 2s - loss: 0.0012 - sparse_categorical_accuracy: 0.9999 - val_loss: 0.0444 - val_sparse_categorical_accuracy: 0.9898 - lr: 4.0000e-05
313/313 - 1s - loss: 0.0363 - sparse_categorical_accuracy: 0.9902
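
The test-set evaluation above (the 313/313 line) and the accuracy curves can be reproduced with a sketch like the following, assuming the history object returned by model.fit and that matplotlib is installed:

import matplotlib.pyplot as plt

# Evaluate on the held-out MNIST test set
model.evaluate(test_data, test_label, verbose=2)

# Plot training vs. validation accuracy per epoch
plt.plot(history.history['sparse_categorical_accuracy'], label='train accuracy')
plt.plot(history.history['val_sparse_categorical_accuracy'], label='validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()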

