1. 在磁盤中保存與加載模型
1.1 保存與加載整個模型
保存整個模型:
- 模型的架構/配置
- 模型的權重值(在訓練過程中學習)
- 模型的編譯信息(如果調用了
compile()
) - 優化器及其狀態(如果有的話,使您可以從上次中斷的位置重新開始訓練)
保存模型
model.save(filepath)
或者
tf.keras.models.save_model(model, filepath)
注意:filepath
的文件格式,如果不加后綴,默認是SavedModel
格式,如果加后綴.h5
,則是HDF5
格式。后者相比前者更加輕量化,但包含內容不如前者。
加載模型
tf.keras.models.load_model(filepath)
注意:如果加載的是h5
格式文件,那么可能會報錯:AttributeError: ‘str’ object has no attribute 'decode
。這是由於h5py
版本過高導致,可以安裝只能版本的h5py
,即pip install h5py==2.10.0
。
舉例
import tensorflow as tf
from tensorflow import keras
def get_model():
model = keras.Sequential()
model.add(keras.Input(shape=(1,)))
model.add(keras.layers.Dense(10, keras.activations.relu))
model.add(keras.layers.Dense(1))
model.compile(optimizer='sgd', loss='mse')
return model
model_1 = get_model()
model_1.save("my_model.h5")
# 或者 model_1.save("my_model")
model_2 = tf.keras.models.load_model("my_model.h5")
# 或者 model_2 = tf.keras.models.load_model("my_model")
1.2 只保存與加載參數
保存參數
model.save_weights(filepath)
注意:filepath
的文件格式,如果不加后綴,默認是TensorFlow Checkpoint
格式,如果加后綴.h5
,則是HDF5
格式。具體差別可看官方文檔。當網絡存在嵌套時,后者可能會有問題。
加載參數
model.load_weights(filepath)
舉例
import tensorflow as tf
from tensorflow import keras
def get_model():
model = keras.Sequential()
model.add(keras.Input(shape=(1,)))
model.add(keras.layers.Dense(10, keras.activations.relu))
model.add(keras.layers.Dense(1))
model.compile(optimizer='sgd', loss='mse')
return model
model_1 = get_model()
model_1.save_weights("my_model_weights.h5")
# 或者 model_1.save_weights("my_model_weights")
model_1.load_weights("my_model_weights.h5")
# 或者 model_1.load_weights("my_model_weights.h5")
使用回調函數
使用回調函數同樣也可以保存和加載模型參數
在訓練時加入ModelCheckpoint
回調函數:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=文件路徑, # 文件存儲理解
save_weights_only=True/False, # 是否只保留參數
save_best_only=True/False # 是否只保留最優結果
)
model.fit(
...
callbacks=[cp_callback]
)
# 加載模型參數
model.load_weights(文件路徑)
2. 在內存中克隆模型
2.1 克隆整個模型
keras.models.clone_model(model)
注意:這里的model
只能是functional model
或 sequential model
,不能是subclass model
2.2 只克隆參數
獲取摸個模型的參數
model.get_weights()
給某個模型的參數賦值
model.set_weights(weights)
參考
https://tensorflow.google.cn/guide/keras/save_and_serialize
https://www.bilibili.com/video/BV1B7411L7Qt?p=22