TensorFlow2中Keras模型保存與加載


主要記錄在Tensorflow2中使用Keras API接口,有關模型保存、加載的內容;

0. 加載數據、構建網絡

首先,為了方便后續有關模型保存、加載相關代碼的正常執行,這里加載mnist數據集、構建一個簡單的網絡結構。

import tensorflow as tf
from libs.load_keras_dataset import load_mnist

注意:下面引入mnist數據集的方式,僅為了方便作者從本地加載、使用;

mnist_path = '/home/chenz/data/mnist/mnist.npz'
(x_train, y_train), (x_test, y_test) = load_mnist(data_path=mnist_path)
print("[INFO] x_train: {}, y_train: {}, x_test: {}, y_test: {}".format(
    x_train.shape, y_train.shape, x_test.shape, y_test.shape
))
train_labels = y_train[:1000]
test_labels = y_test[:1000]
train_images = x_train[:1000].reshape(-1, 28*28) / 255.0
test_images = x_test[:1000].reshape(-1, 28*28) / 255.0

print("[INFO] train_images: {}, train_labels: {}, test_images: {}, test_labels: {}".format(
    train_images.shape, train_labels.shape, test_images.shape, test_labels.shape
))
[INFO] x_train: (60000, 28, 28), y_train: (60000,), x_test: (10000, 28, 28), y_test: (10000,)
[INFO] train_images: (1000, 784), train_labels: (1000,), test_images: (1000, 784), test_labels: (1000,)

定義一個方法,用於構建網絡結構,並定義網絡編譯方式,方便后續使用;

# Build Model
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10)
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.1, beta_2=0.2, amsgrad=True),
                  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model

1. model.save() & model.save_weights()

在TensorFlow的Keras API中提供了兩種保存模型的方式,分別為model.save()model.save_weights(),從字面上可以簡單理解,后者僅保存網絡結構權重,前者能夠保存整個模型結構

進一步,從源碼文檔中可以理清兩者的區別:

1.1 model.save()

該方法能夠將整個模型進行保存,以兩種方式存儲,Tensorflow SavedModelHDF file,保存的文件包括:

  • 模型結構,能夠重新實例化模型;
  • 模型權重;
  • 優化器的狀態,在上次中斷的地方繼續訓練;

可以通過tf.keras.models.load_model重新實例化保存的模型,通過該方法返回的模型是已經編譯過的模型,除非在之前保存模型的時候就沒有被編譯;

利用SequentialFunctional兩種形式構建的網絡都能夠保存成HDF5和SavedModel格式,但是Subclasses形式的模型僅能夠保存成SavedModel格式;

# HDF5格式
model_name.h5

# Tensorflow SavedModel格式
./saved_model
	assets/
	saved_model.pb
	variables/

使用參數說明:

def save(self,
           filepath,
           overwrite=True,
           include_optimizer=True,
           save_format=None,
           signatures=None,
           options=None):
  • filepath表示模型存儲的路徑;

  • save_format表示以tf或者h5形式進行存儲,在TF2中默認tf,TF1中默認h5

  • overwrite表示是否覆蓋在目標目錄下的已有文件;

  • include_optimizer表示是否保存優化器的狀態;

  • signatures僅用於tf形式,具體使用見tf.saved_model.save

filepathsave_format結合在一起使用,有如下組合方式:

  • filepath.h5為結尾的文件名,則不論save_formattf或者h5,則模型將保存成filename.h5形式;(上級目錄需要存在)
  • filepath僅指定文件名,save_format='h5',則模型將保存成filename的HDF形式;
  • filepath指定路徑(需存在),save_format='tf',則模型將以Tensorflow SavedModel形式保存到指定路徑下;

注意:filepath不包含后綴時,注意區分是文件目錄還是文件名,以tf形式保存,則需要存在指定路徑,以h5形式保存,則不能存在相同名稱路徑;

1.2 model.save_weights()

該方法僅保存網絡中所有層的權重,

# HDF5格式
weights_2 or weights_3.h5

# Tensorflow 格式
checkpoint 
weiths_1.data-00000-of-00001
weigths_1.index

使用參數說明:

def save_weights(self,
                 filepath,
                 overwrite=True,
                 save_format=None,
                 options=None):
  • filepath表示存儲的模型文件名或路徑;
  • save_format用於表示存儲格式,HDF5或者Tensorflow格式;

filepathsave_format結合使用:

  • filepath以后綴.h5或者.keras結尾,設置save_format=None或者save_format=None,模型將保存成filename.h5filename.keras格式;
  • filepath不含后綴,如果save_format='h5',則模型保存成filename
  • filepath不含后綴,如果save_format='tf'或者save_format=None,則模型保存成Tensorflow格式;

2. tf.keras.callbacks.ModelCheckpoint

該方法以回調函數的形式,在模型訓練過程中保存模型。

def __init__(self,
             filepath,
             monitor='val_loss',
             verbose=0,
             save_best_only=False,
             save_weights_only=False,
             mode='auto',
             save_freq='epoch',
             options=None,
             **kwargs):

這里僅提及一點,就是在使用參數save_weights_only時:

  • 設置True,則調用model.save_weights()
  • 設置False,則調用model.save()

使用方式:

checkpoint_path = "./saved_model/save_and_load/cp_test_1/cp.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=False,
                                                 verbose=1)
model.fit(train_images, train_labels,
          epochs=3,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])

3. tf.keras.models.load_model、model.load_weights

上面簡單說明了模型保存的兩種方式,一種是保存整個模型,另一種則是僅保存模型權重;

完整的模型可以使用tf.keras.models.load_model加載,只包含權重的模型則使用model.load_weights加載;

3.1 tf.keras.models.load_model

加載完整模型

model_path = './saved_model/save_and_load/save_test/test_5/'

model = tf.keras.models.load_model(model_path)
model.summary()
  • 其中,model_path可以為.h5文件的路徑,或者Tensorflow SavedModel的路徑

3.2 model.load_weights

在重新構建網絡的基礎上,加載模型權重;

model = create_model()
model.load_weights("./saved_model/save_and_load/save_test/weights/weights_1")
model.summary()

4. 總結

  • 官方API是推薦Tensorflow格式進行保存模型,不論是保存整個模型,或是僅保存權重;


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM