Solving the MNIST handwritten digit recognition problem with an AlexNet network


Code (TensorFlow 2.0):

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, datasets, Sequential, metrics, optimizers
import os
import numpy as np

tf.random.set_seed(0)
np.random.seed(0)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TF info/warning logs
assert tf.__version__.startswith('2.')

batch_size = 128
optimizer = optimizers.Adam(0.000005)
epochs = 2

def preprocess(x, y):
    # Scale pixels to [0, 1] and one-hot encode the labels.
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train[:, :, :, np.newaxis]  # add a channel dimension: (N, 28, 28) -> (N, 28, 28, 1)
x_test = x_test[:, :, :, np.newaxis]
print('train_shape:', x_train.shape, y_train.shape)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(50000).batch(batch_size)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)

class AlexNet(keras.Model):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.conv = Sequential([
            # unit1 [b,28,28,1] => [b,14,14,16]
            layers.Conv2D(16, (3, 3), padding='same', strides=1, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same'),
            layers.BatchNormalization(),  # use a BN layer in place of LRN

            # unit2 [b,14,14,16] => [b,7,7,32]
            layers.Conv2D(32, (3, 3), padding='same', strides=1, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same'),
            layers.BatchNormalization(),

            # unit3 [b,7,7,32] => [b,7,7,64]
            layers.Conv2D(64, (3, 3), padding='same', strides=1, activation=tf.nn.relu),

            # unit4 [b,7,7,64] => [b,7,7,128]
            layers.Conv2D(128, (3, 3), padding='same', strides=1, activation=tf.nn.relu),

            # unit5 [b,7,7,128] => [b,4,4,256]
            layers.Conv2D(256, (3, 3), padding='same', strides=1, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same'),
            layers.BatchNormalization(),
        ])

        self.fc = Sequential([
            # fc1
            layers.Dense(4096, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc2
            layers.Dense(2048, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc3
            layers.Dense(1024, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc4: no activation here, since the loss is computed with from_logits=True
            layers.Dense(10)
        ])

    def call(self, inputs, training=None):
        x = inputs
        out = self.conv(x)
        out = tf.reshape(out, (-1, 4 * 4 * 256))
        out = self.fc(out)
        return out

model = AlexNet()

# Sanity-check the model's output shape
# sample = tf.random.normal((1, 28, 28, 1))
# out = model(sample)
# print('out_shape:', out.shape)

# Print the model architecture
# model.build(input_shape=(None, 28, 28, 1))
# model.summary()


def main():
    # Manual training loop (equivalent to the compile/fit form below):
    # for epoch in range(epochs):
    #     for step, (x, y) in enumerate(train_db):
    #         # y is already one-hot encoded by preprocess()
    #         with tf.GradientTape() as tape:
    #             logits = model(x)
    #             loss = tf.losses.categorical_crossentropy(y, logits, from_logits=True)
    #             loss = tf.reduce_mean(loss)
    #         grads = tape.gradient(loss, model.trainable_variables)
    #         optimizer.apply_gradients(zip(grads, model.trainable_variables))
    #
    #         if step % 100 == 0:
    #             print('epoch:', epoch, 'step:', step, 'loss:', float(loss))
    #
    #     # test
    #     total_correct = 0
    #     total_num = 0
    #     for step, (x, y) in enumerate(test_db):
    #         logits = model(x)
    #         prob = tf.nn.softmax(logits, axis=1)
    #         pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)
    #         y_true = tf.cast(tf.argmax(y, axis=1), dtype=tf.int32)  # undo the one-hot encoding
    #         correct = tf.reduce_sum(tf.cast(tf.equal(pred, y_true), dtype=tf.int32))
    #
    #         total_correct += correct
    #         total_num += x.shape[0]
    #     acc = total_correct / total_num
    #     print('epoch:', epoch, 'acc:', float(acc))

    # Compact form: requires one-hot labels (already handled in preprocess via tf.one_hot(y, depth=10))
    model.compile(optimizer=optimizers.Adam(learning_rate=0.0001),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['acc'])
    model.fit(train_db, epochs=epochs, validation_data=test_db, validation_freq=2)
    model.save_weights('./checkpoint/weights.ckpt')
    print('save weights')

    # evaluate
    model.evaluate(test_db)



if __name__ == '__main__':
main()
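
Since the script saves its weights to ./checkpoint/weights.ckpt, the trained model can be rebuilt and restored later without retraining. A minimal sketch, assuming the script above has already run so the checkpoint exists; the variable names (restored, sample, pred) and the choice of test sample are illustrative:

# Rebuild the architecture, restore the trained weights, then classify one test digit.
restored = AlexNet()
restored.build(input_shape=(None, 28, 28, 1))
restored.load_weights('./checkpoint/weights.ckpt')

sample = tf.cast(x_test[:1], tf.float32) / 255.  # same scaling as preprocess()
logits = restored(sample)
pred = int(tf.argmax(logits, axis=1)[0])
print('predicted:', pred, 'label:', int(y_test[0]))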
Using AlexNet for MNIST is admittedly overkill: the network is fairly deep and has many parameters to train, and a shallow network already achieves good results on MNIST classification. I used AlexNet here simply as a beginner's exercise to get familiar with the architecture. Accuracy on the test set reaches about 98%, which is decent.
AlexNet consists of 5 convolutional layers and 3 fully connected layers (I used 4 fully connected layers here so that the output width shrinks gradually rather than dropping too sharply at once).
Dropout layers with a rate of 0.4 are used to prevent overfitting. During training, the choice of loss function, learning rate, and other hyperparameters matters a great deal: different settings lead to noticeably different training times and accuracies (a quick learning-rate sweep is sketched below).
Test-set accuracy: (screenshot omitted)

AlexNet network structure: (diagram omitted)
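
Since the learning rate has such a visible effect, here is a hedged sketch of a quick comparison; the candidate rates below are illustrative assumptions, not values from the original run:

# Train a fresh model for one epoch at each candidate learning rate and
# compare validation accuracy. The candidate rates are assumptions.
for rate in (1e-3, 1e-4, 1e-5):
    m = AlexNet()
    m.compile(optimizer=optimizers.Adam(learning_rate=rate),
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['acc'])
    history = m.fit(train_db, epochs=1, validation_data=test_db, verbose=0)
    print('learning rate:', rate, 'val_acc:', history.history['val_acc'][-1])

One epoch per setting is only a rough signal, but it is usually enough to rule out rates that diverge or barely move the loss.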

  Next, I plan to use an AlexNet network to tackle a flower-species classification problem, and I hope it gives good results.

