断点续训,即在一次训练结束后,可以先将得到的最优训练参数保存起来,待到下次训练时,直接读取最优参数,在此基础上继续训练。
读取模型参数:
存储模型参数的文件格式为 ckpt(checkpoint)。
生成 ckpt 文件时,会同步生成索引表,所以可通过判断是否存在索引表来判断是否存在模型参数。
# 模型参数保存路径 checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path)
保存模型参数:
# 定义回调函数,在模型训练时,回调函数会被执行,完成保留参数操作
cp_callback = tf.keras.callbacks.ModelCheckpoint( # 文件保存路径
filepath=checkpoint_save_path, # 是否只保留模型参数
save_weights_only=True, # 是否只保留最优结果
save_best_only=True ) # 执行训练过程,保存新的训练参数
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback])
代码:
import tensorflow as tf import os # 读取输入特征和标签
mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() # 数据归一化,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0
# 声明网络结构
model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax") ]) # 配置训练方法
model.compile(optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=[tf.keras.metrics.sparse_categorical_accuracy]) # 如果存在参数文件,直接读取,在此基础上继续训练
checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt" # 模型参数保存路径
if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path) # 定义回调函数,在模型训练时,完成保留参数操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True, save_best_only=True) # 执行训练过程,保存新的训练参数
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback]) # 打印网络结构和参数
model.summary()