代码(Tensorflow2.0):
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, datasets, Sequential, metrics, optimizers
import os
import numpy as np
# Fix the TensorFlow and NumPy random seeds so repeated runs are comparable.
tf.random.set_seed(0)
np.random.seed(0)
# Hide TensorFlow's C++ INFO and WARNING log lines. The variable TensorFlow
# reads is TF_CPP_MIN_LOG_LEVEL; the key 'TF_CPP_MIN_LEVEL' used previously
# is not recognised and had no effect.
# NOTE(review): this runs after `import tensorflow`; if start-up messages must
# also be suppressed, set the variable before the import -- confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# The script relies on the TensorFlow 2.x API.
assert tf.__version__.startswith('2.')
# Mini-batch size shared by the training and test pipelines.
batch_size = 128
# Optimizer referenced only by the (disabled) manual training loop in main().
optimizer = optimizers.Adam(0.000005)
# Number of passes over the training set.
epochs = 2
def preprocess(x, y):
    """Convert one (image, label) pair into model-ready tensors.

    Args:
        x: Image data; cast to float32 and divided by 255.
        y: Class index; cast to int32 and one-hot encoded with depth 10.

    Returns:
        A ``(image, label)`` tuple of the scaled image and the one-hot label.
    """
    label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    image = tf.cast(x, dtype=tf.float32) / 255.
    return image, label
# Load MNIST and append a trailing channel axis to the image arrays.
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]
print('trian_shape:', x_train.shape, y_train.shape)

# Training pipeline: preprocess each sample, shuffle, then batch.
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess)
train_db = train_db.shuffle(50000)
train_db = train_db.batch(batch_size)

# Test pipeline: same preprocessing and batching, but no shuffling.
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess)
test_db = test_db.batch(batch_size)
class AlexNet(keras.Model):
    """AlexNet-style CNN for [b, 28, 28, 1] inputs with 10 output classes.

    ``self.conv`` stacks five 3x3 convolutions with three max-pool +
    batch-norm stages; ``self.fc`` stacks four dense layers with dropout
    between them. The network returns raw logits (no softmax), matching the
    ``from_logits=True`` loss used in ``main``.
    """

    def __init__(self):
        super().__init__()
        self.conv = Sequential([
            # unit1 [b,28,28,1] => [b,14,14,16]
            layers.Conv2D(16, (3, 3), padding='same', strides=1, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same'),
            layers.BatchNormalization(),  # batch norm is used in place of LRN
            # unit2 [b,14,14,16] => [b,7,7,32]
            layers.Conv2D(32, (3, 3), padding='same', strides=1, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same'),
            layers.BatchNormalization(),
            # unit3 [b,7,7,32] => [b,7,7,64]
            layers.Conv2D(64, (3, 3), padding='same', strides=1, activation=tf.nn.relu),
            # unit4 [b,7,7,64] => [b,7,7,128]
            layers.Conv2D(128, (3, 3), padding='same', strides=1, activation=tf.nn.relu),
            # unit5 [b,7,7,128] => [b,4,4,256]
            layers.Conv2D(256, (3, 3), padding='same', strides=1, activation=tf.nn.relu),
            layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same'),
            layers.BatchNormalization(),
        ])
        self.fc = Sequential([
            # fc1
            layers.Dense(4096, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc2
            layers.Dense(2048, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc3
            layers.Dense(1024, activation=tf.nn.relu),
            layers.Dropout(0.4),
            # fc4: raw logits. No activation here -- the loss is created with
            # from_logits=True, and a ReLU would clamp every logit to >= 0.
            layers.Dense(10),
        ])

    def call(self, inputs, training=None):
        """Run a forward pass.

        Args:
            inputs: Batch of images, [b, 28, 28, 1].
            training: Forwarded to both sub-networks so that Dropout and
                BatchNormalization switch between training and inference
                behaviour.

        Returns:
            Logits of shape [b, 10].
        """
        out = self.conv(inputs, training=training)
        # Flatten the [b, 4, 4, 256] feature maps for the dense layers.
        out = tf.reshape(out, (-1, 4*4*256))
        out = self.fc(out, training=training)
        return out
model = AlexNet()
# Check the network's output shape
# sample = tf.random.normal((1, 28, 28, 1))
# out = model(sample)
# print('out_shape:', out.shape)
# Print the network structure
# model.build(input_shape=(None, 28, 28, 1))
# model.summary()
def main():
    """Compile, train, checkpoint and evaluate the module-level ``model``.

    Trains on ``train_db`` for ``epochs`` epochs, validates on ``test_db``,
    writes the weights to ``./checkpoint/weights.ckpt`` and finally evaluates
    on ``test_db``.
    """
    # Manual training loop, kept for reference (disabled). It one-hot encodes
    # `y` itself and compares predictions against `y` directly, so it expects
    # integer labels rather than the one-hot labels `preprocess` produces.
    # for epoch in range(epochs):
    #     for step, (x, y) in enumerate(train_db):
    #         with tf.GradientTape() as tape:
    #             logits = model(x)
    #             y_onehot = tf.one_hot(y, depth=10)
    #             loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
    #             loss = tf.reduce_mean(loss)
    #         grads = tape.gradient(loss, model.trainable_variables)
    #         optimizer.apply_gradients(zip(grads, model.trainable_variables))
    #
    #         if step % 100 == 0:
    #             print('epoch:', epoch, 'step:', step, 'loss:', float(loss))
    #
    #     # test
    #     total_correct = 0
    #     total_num = 0
    #     for step, (x, y) in enumerate(test_db):
    #         logits = model(x)
    #         prob = tf.nn.softmax(logits, axis=1)
    #         pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)
    #         correct = tf.reduce_sum(tf.cast(tf.equal(pred, y), dtype=tf.int32))
    #
    #         total_correct += correct
    #         total_num += x.shape[0]
    #     acc = total_correct/total_num
    #     print('epoch:', epoch, 'acc:', float(acc))

    # Keras shorthand. It needs one-hot labels (y = tf.one_hot(y, depth=10)),
    # which `preprocess` already supplies.
    # `learning_rate` replaces the deprecated `lr` keyword alias.
    model.compile(optimizer=optimizers.Adam(learning_rate=0.0001),
                  loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['acc'])
    # validation_freq=2: run validation on test_db every second epoch.
    model.fit(train_db, epochs=epochs, validation_data=test_db, validation_freq=2)
    # NOTE(review): newer Keras releases may require a `.weights.h5` suffix
    # for save_weights -- confirm against the installed version.
    model.save_weights('./checkpoint/weights.ckpt')
    print('save weights')
    # evaluate
    model.evaluate(test_db)


if __name__ == '__main__':
    main()
使用AlexNet神经网络解决mnist问题确实有点大材小用的意思,因为AlexNet网络较深,需要训练的参数比较多,而对于mnist分类问题来说,浅层网络就
可以达到不错的效果。这里使用AlexNet网络只是作为一个新手来熟悉一下网络,练个手。在测试集上的准确率可以达到98%左右,效果还可以。
AlexNet网络包括5个卷积层和3个全连接层(我这里使用了4个全连接层,是希望输出维度可以逐层递减,不要一下子下降太多),并
使用dropout层来防止过拟合,丢弃率为0.4。在训练过程中,损失函数、learning_rate等参数的选择比较重要,选用不同的参数,训练的时间和准确率也不一样。
测试集准确率:
AlexNet网络结构:
接下来我将使用AlexNet网络解决一个不同花种分类的问题,希望可以得到不错的结果。