tensorflow2.0保存模型的方式有很多,這里只介紹兩種。
一、 使用官方模型
這種情況可以直接保存整個模型,如下所示,可以將模型保存為HDF5文件
# 創建模型實例
model = create_model()
# 保存模型到HDF5文件
model.save('my_model.h5')
# 讀取模型
model = keras.models.load_model('my_model.h5')
二、自定義模型
如果是自定義模型使用上述方法保存會報錯且保存失敗,報錯為:
NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn’t safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format=“tf”) or using
save_weights
.
這種情況可以保存weight。
# 創建模型
model = create_model()
# 保存權重
model.save_weights('model_weight')
# 創建新模型讀取權重
newModel = create_model()
# 讀取權重到新模型
newModel.load_weights('model_weight')
參考文獻: