# Code (TensorFlow 2.0)
# TF_CPP_MIN_LOG_LEVEL is read when TensorFlow is first imported, so it must
# be set BEFORE the tensorflow import to actually suppress C++ log output.
# BUGFIX: the variable was misspelled 'TF_CPP_MIN_LEVEL' and set after import.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, Sequential, optimizers, metrics

# Fix the RNG seeds for reproducible runs.
tf.random.set_seed(0)
np.random.seed(0)

assert tf.__version__.startswith('2.')

# Hyperparameters.
batch_size = 128
# 'learning_rate' is the documented parameter name; 'lr' is a deprecated alias.
optimizer = optimizers.Adam(learning_rate=0.0001)
epochs = 20
# CIFAR-100: images arrive as (N, 32, 32, 3) uint8, labels as (N, 1) ints.
(x_train, y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("train_shape:", x_train.shape, y_train.shape)
# Drop the trailing label axis: [N, 1] -> [N,]
y_test = tf.squeeze(y_test, axis=1)
y_train = tf.squeeze(y_train, axis=1)
print("train_shape:", x_train.shape, y_train.shape)
def preprocess(x, y):
    """Scale images into [0, 1] float32 and cast labels to int32."""
    return tf.cast(x, tf.float32) / 255., tf.cast(y, tf.int32)
# Input pipelines: normalize with preprocess, shuffle the training split
# over the full 50k examples, then batch both splits.
train_db = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .map(preprocess)
            .shuffle(50000)
            .batch(batch_size))
test_db = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
           .map(preprocess)
           .batch(batch_size))
class BasicBlock(layers.Layer):
    """ResNet basic block: two 3x3 conv/BN units plus a shortcut connection.

    When ``strides != 1`` the main branch shrinks the spatial size, so the
    shortcut is replaced by a strided 1x1 convolution to keep both branches
    addable; otherwise the shortcut is the identity.
    """

    def __init__(self, filter_num, strides=1):
        """
        Args:
            filter_num: number of output channels for both conv layers.
            strides: stride of the first conv (and of the shortcut conv).
        """
        super(BasicBlock, self).__init__()
        # unit 1: the strided conv performs any down-sampling
        self.conv1 = layers.Conv2D(filters=filter_num, kernel_size=(3, 3),
                                   strides=strides, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        # unit 2: always stride 1
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        if strides != 1:
            # 1x1 strided conv so the shortcut matches the main branch shape.
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=strides))
        else:
            self.downsample = lambda x: x

    def call(self, inputs, training=None):
        """Forward pass. BUGFIX: forward `training` to the BatchNormalization
        layers; the original never passed it, so BN could run with the wrong
        mode (inference statistics during training) inside graph execution."""
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        # Residual addition followed by the final non-linearity.
        identity = self.downsample(inputs)
        return tf.nn.relu(layers.add([out, identity]))
class ResNet(keras.Model):
    """ResNet assembled from BasicBlocks (layers_dims=[2, 2, 2, 2] -> ResNet-18)."""

    def __init__(self, layers_dims, num_classes=100):
        """
        Args:
            layers_dims: number of BasicBlocks in each of the four stages.
            num_classes: width of the final classification layer.
        """
        super(ResNet, self).__init__()
        # Stem (pre-processing) layers.
        self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D((2, 2), strides=(1, 1), padding='same')])
        # Four residual stages; stages 2-4 halve the spatial resolution.
        self.resblock1 = self.ResBlock(64, blocks=layers_dims[0])
        self.resblock2 = self.ResBlock(128, blocks=layers_dims[1], strides=2)
        self.resblock3 = self.ResBlock(256, blocks=layers_dims[2], strides=2)
        self.resblock4 = self.ResBlock(512, blocks=layers_dims[3], strides=2)
        # Classification head.
        self.avgpool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None):
        # Forward `training` everywhere so BatchNormalization switches
        # correctly between train and inference behaviour.
        out = self.stem(inputs, training=training)
        out = self.resblock1(out, training=training)
        out = self.resblock2(out, training=training)
        out = self.resblock3(out, training=training)
        out = self.resblock4(out, training=training)
        out = self.avgpool(out)
        return self.fc(out)

    def ResBlock(self, filter_num, blocks, strides=1):
        """Stack `blocks` BasicBlocks; only the first one may down-sample.

        BUGFIX: the original added one block and then looped `range(blocks)`
        more, producing blocks + 1 BasicBlocks per stage — a [2, 2, 2, 2]
        config silently built a 26-layer network instead of ResNet-18.
        """
        resblock = Sequential()
        resblock.add(BasicBlock(filter_num, strides))
        for _ in range(1, blocks):
            resblock.add(BasicBlock(filter_num, strides=1))
        return resblock
# ResNet-18: 18 conv layers = 1 stem conv + 4 stages * 2 BasicBlocks * 2 convs
# + 1 dense classifier.
resnet_18 = ResNet([2, 2, 2, 2])
# Create the variables for CIFAR-sized input and print the architecture.
resnet_18.build(input_shape=(None, 32, 32, 3))
resnet_18.summary()
def main():
    """Train ResNet-18 on CIFAR-100, periodically evaluating and checkpointing."""
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                # training=True so BatchNormalization uses batch statistics
                # and updates its moving averages.
                logits = resnet_18(x, training=True)
                y_onehot = tf.one_hot(y, depth=100)
                # categorical_crossentropy takes (y_true, y_pred) in that
                # order; swapping them breaks the loss.
                loss = tf.losses.categorical_crossentropy(y_onehot, logits,
                                                          from_logits=True)
                loss = tf.reduce_mean(loss)
            grads = tape.gradient(loss, resnet_18.trainable_variables)
            optimizer.apply_gradients(zip(grads, resnet_18.trainable_variables))
            if step % 10 == 0:
                print(epoch, step, 'loss:', float(loss))
            if step % 50 == 0:
                # Evaluate accuracy on the full test set.
                total_correct = 0
                total_num = 0
                # BUGFIX: the original reused `step` as the inner loop
                # variable, clobbering the training-step counter that the
                # accuracy print below reports.
                for x_val, y_val in test_db:
                    logits = resnet_18(x_val, training=False)
                    prob = tf.nn.softmax(logits, axis=1)
                    pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)
                    correct = tf.reduce_sum(
                        tf.cast(tf.equal(pred, y_val), dtype=tf.int32))
                    total_correct += int(correct)
                    total_num += x_val.shape[0]
                acc = total_correct / total_num
                print(epoch, step, 'acc:', float(acc))
                resnet_18.save_weights('./checkpoint/weights.ckpt')
                print('save weights')
# Run training only when executed as a script (not on import).
if __name__ == '__main__':
    main()
# Compared with plain convolutional networks, ResNet adds shortcut (residual)
# connections, which let network depth grow past previous limits and make
# high-level semantic feature extraction possible.