tensorflow2.0中保存、加載、克隆模型


1. 在磁盤中保存與加載模型

1.1 保存與加載整個模型

保存整個模型:

  • 模型的架構/配置
  • 模型的權重值(在訓練過程中學習)
  • 模型的編譯信息(如果調用了 compile()
  • 優化器及其狀態(如果有的話,使您可以從上次中斷的位置重新開始訓練)

保存模型

model.save(filepath)

或者

tf.keras.models.save_model(model, filepath)

注意:filepath的文件格式,如果不加后綴,默認是SavedModel格式,如果加后綴.h5,則是HDF5格式。后者相比前者更加輕量化,但包含內容不如前者。

加載模型

tf.keras.models.load_model(filepath)

注意:如果加載的是h5格式文件,那么可能會報錯:AttributeError: ‘str’ object has no attribute 'decode。這是由於h5py版本過高導致,可以安裝只能版本的h5py,即pip install h5py==2.10.0

舉例

import tensorflow as tf
from tensorflow import keras

def get_model():
    model = keras.Sequential()
    model.add(keras.Input(shape=(1,)))
    model.add(keras.layers.Dense(10, keras.activations.relu))
    model.add(keras.layers.Dense(1))
    model.compile(optimizer='sgd',  loss='mse')
    return model

model_1 = get_model()

model_1.save("my_model.h5")
# 或者 model_1.save("my_model")

model_2 = tf.keras.models.load_model("my_model.h5")
# 或者 model_2 = tf.keras.models.load_model("my_model")

1.2 只保存與加載參數

保存參數

model.save_weights(filepath)

注意:filepath的文件格式,如果不加后綴,默認是TensorFlow Checkpoint格式,如果加后綴.h5,則是HDF5格式。具體差別可看官方文檔。當網絡存在嵌套時,后者可能會有問題。

加載參數

model.load_weights(filepath)

舉例

import tensorflow as tf
from tensorflow import keras

def get_model():
    model = keras.Sequential()
    model.add(keras.Input(shape=(1,)))
    model.add(keras.layers.Dense(10, keras.activations.relu))
    model.add(keras.layers.Dense(1))
    model.compile(optimizer='sgd',  loss='mse')
    return model

model_1 = get_model()

model_1.save_weights("my_model_weights.h5")
# 或者 model_1.save_weights("my_model_weights")

model_1.load_weights("my_model_weights.h5")
# 或者 model_1.load_weights("my_model_weights.h5")

使用回調函數

使用回調函數同樣也可以保存和加載模型參數

在訓練時加入ModelCheckpoint回調函數:

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=文件路徑,  # 文件存儲理解
    save_weights_only=True/False,  # 是否只保留參數
    save_best_only=True/False  # 是否只保留最優結果
)

model.fit(
  ...
  callbacks=[cp_callback]
)

# 加載模型參數
model.load_weights(文件路徑)

2. 在內存中克隆模型

2.1 克隆整個模型

keras.models.clone_model(model)

注意:這里的model只能是functional model sequential model,不能是subclass model

2.2 只克隆參數

獲取摸個模型的參數

model.get_weights()

給某個模型的參數賦值

model.set_weights(weights)

參考

https://tensorflow.google.cn/guide/keras/save_and_serialize
https://www.bilibili.com/video/BV1B7411L7Qt?p=22


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM