主要記錄在Tensorflow2中使用Keras API接口,有關模型保存、加載的內容;
0. 加載數據、構建網絡
首先,為了方便后續有關模型保存、加載相關代碼的正常執行,這里加載mnist數據集、構建一個簡單的網絡結構。
import tensorflow as tf
from libs.load_keras_dataset import load_mnist
注意:下面引入mnist數據集的方式,僅為了方便作者從本地加載、使用;
mnist_path = '/home/chenz/data/mnist/mnist.npz'
(x_train, y_train), (x_test, y_test) = load_mnist(data_path=mnist_path)
print("[INFO] x_train: {}, y_train: {}, x_test: {}, y_test: {}".format(
x_train.shape, y_train.shape, x_test.shape, y_test.shape
))
train_labels = y_train[:1000]
test_labels = y_test[:1000]
train_images = x_train[:1000].reshape(-1, 28*28) / 255.0
test_images = x_test[:1000].reshape(-1, 28*28) / 255.0
print("[INFO] train_images: {}, train_labels: {}, test_images: {}, test_labels: {}".format(
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape
))
[INFO] x_train: (60000, 28, 28), y_train: (60000,), x_test: (10000, 28, 28), y_test: (10000,)
[INFO] train_images: (1000, 784), train_labels: (1000,), test_images: (1000, 784), test_labels: (1000,)
定義一個方法,用於構建網絡結構,並定義網絡編譯方式,方便后續使用;
# Build Model
def create_model():
model = tf.keras.Sequential([
tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.1, beta_2=0.2, amsgrad=True),
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
1. model.save() & model.save_weights()
在TensorFlow的Keras API中提供了兩種保存模型的方式,分別為model.save()
、model.save_weights()
,從字面上可以簡單理解,后者僅保存網絡結構權重,前者能夠保存整個模型結構;
進一步,從源碼文檔中可以理清兩者的區別:
1.1 model.save()
該方法能夠將整個模型進行保存,以兩種方式存儲,Tensorflow SavedModel
、HDF file
,保存的文件包括:
- 模型結構,能夠重新實例化模型;
- 模型權重;
- 優化器的狀態,在上次中斷的地方繼續訓練;
可以通過tf.keras.models.load_model
重新實例化保存的模型,通過該方法返回的模型是已經編譯過的模型,除非在之前保存模型的時候就沒有被編譯;
利用Sequential
和Functional
兩種形式構建的網絡都能夠保存成HDF5和SavedModel格式,但是Subclasses
形式的模型僅能夠保存成SavedModel格式;
# HDF5格式
model_name.h5
# Tensorflow SavedModel格式
./saved_model
assets/
saved_model.pb
variables/
使用參數說明:
def save(self,
filepath,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None,
options=None):
-
filepath
表示模型存儲的路徑; -
save_format
表示以tf
或者h5
形式進行存儲,在TF2中默認tf
,TF1中默認h5
; -
overwrite
表示是否覆蓋在目標目錄下的已有文件; -
include_optimizer
表示是否保存優化器的狀態; -
signatures
僅用於tf
形式,具體使用見tf.saved_model.save
;
filepath
和save_format
結合在一起使用,有如下組合方式:
filepath
以.h5
為結尾的文件名,則不論save_format
是tf
或者h5
,則模型將保存成filename.h5
形式;(上級目錄需要存在)filepath
僅指定文件名,save_format='h5'
,則模型將保存成filename
的HDF形式;filepath
指定路徑(需存在),save_format='tf'
,則模型將以Tensorflow SavedModel
形式保存到指定路徑下;
注意:filepath
不包含后綴時,注意區分是文件目錄還是文件名,以tf
形式保存,則需要存在指定路徑,以h5
形式保存,則不能存在相同名稱路徑;
1.2 model.save_weights()
該方法僅保存網絡中所有層的權重,
# HDF5格式
weights_2 or weights_3.h5
# Tensorflow 格式
checkpoint
weiths_1.data-00000-of-00001
weigths_1.index
使用參數說明:
def save_weights(self,
filepath,
overwrite=True,
save_format=None,
options=None):
filepath
表示存儲的模型文件名或路徑;save_format
用於表示存儲格式,HDF5
或者Tensorflow
格式;
filepath
與save_format
結合使用:
filepath
以后綴.h5
或者.keras
結尾,設置save_format=None
或者save_format=None
,模型將保存成filename.h5
或filename.keras
格式;filepath
不含后綴,如果save_format='h5'
,則模型保存成filename
;filepath
不含后綴,如果save_format='tf'
或者save_format=None
,則模型保存成Tensorflow
格式;
2. tf.keras.callbacks.ModelCheckpoint
該方法以回調函數的形式,在模型訓練過程中保存模型。
def __init__(self,
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch',
options=None,
**kwargs):
這里僅提及一點,就是在使用參數save_weights_only
時:
- 設置
True
,則調用model.save_weights()
; - 設置
False
,則調用model.save()
;
使用方式:
checkpoint_path = "./saved_model/save_and_load/cp_test_1/cp.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=False,
verbose=1)
model.fit(train_images, train_labels,
epochs=3,
validation_data=(test_images, test_labels),
callbacks=[cp_callback])
3. tf.keras.models.load_model、model.load_weights
上面簡單說明了模型保存的兩種方式,一種是保存整個模型,另一種則是僅保存模型權重;
完整的模型可以使用tf.keras.models.load_model
加載,只包含權重的模型則使用model.load_weights
加載;
3.1 tf.keras.models.load_model
加載完整模型
model_path = './saved_model/save_and_load/save_test/test_5/'
model = tf.keras.models.load_model(model_path)
model.summary()
- 其中,
model_path
可以為.h5
文件的路徑,或者Tensorflow SavedModel
的路徑
3.2 model.load_weights
在重新構建網絡的基礎上,加載模型權重;
model = create_model()
model.load_weights("./saved_model/save_and_load/save_test/weights/weights_1")
model.summary()
4. 總結
- 官方API是推薦Tensorflow格式進行保存模型,不論是保存整個模型,或是僅保存權重;