tf2 模型保存總結


tf2 模型保存總結

1. model.save保存的是所有信息,結果是單文件,最為簡單。

實例:保

model_name = "./model_save/fassionMnist_save.h5"

model.save(model_name)

new_model = keras.models.load_model(model_name)

2. model.save_weights(weight_file)保存的是權重,結果是單文件。

weight_file="./model_save/weights.h5"

示例:保

model.save_weights(weight_file)

 

model = keras.Sequential()

model.add(keras.layers.Flatten(input_shape=(28,28)))

model.add(keras.layers.Dense(128,activation="relu"))

model.add(keras.layers.Dense(10, activation="softmax"))

model.summary()

 

model.compile(optimizer="adam",

loss="sparse_categorical_crossentropy",

metrics=["acc"])

 

model.load_weights(weight_file)

3. 檢查點保存權重,結果多文件

示例:

ckpt_path="./ckpt/model_ckpt.ckpt"

ckpt_callback=keras.callbacks.ModelCheckpoint(

ckpt_path,save_weights_only=True)

history = model.fit(train_image,train_label,epochs=3,callbacks=[ckpt_callback])

 

model = keras.Sequential()

model.add(keras.layers.Flatten(input_shape=(28,28)))

model.add(keras.layers.Dense(128,activation="relu"))

model.add(keras.layers.Dense(10, activation="softmax"))

model.summary()

 

model.compile(optimizer="adam",

loss="sparse_categorical_crossentropy",

metrics=["acc"])

 

model.load_weights(ckpt_path)

 

4. 檢查點保存全部模型,結果是文件夾

而且win下保存路徑必須用 反斜杠,不能用正斜杠,可視為bug

model_ckpt_path=".\ckpt\model3.model"

ckpt_callback=keras.callbacks.ModelCheckpoint(

model_ckpt_path,save_weights_only=False)

model.evaluate(test_image,test_label,verbose=0)

history = model.fit(train_image,train_label,epochs=3,callbacks=[ckpt_callback])

model.evaluate(test_image,test_label,verbose=0)

 

new_model = keras.models.load_model(model_ckpt_path)

new_model.evaluate(test_image,test_label,verbose=0)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM