在keras中保存模型有幾種方式:
(1):使用callbacks,可以保存訓練中任意的模型,或選擇最好的模型
logdir = './callbacks'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir, "xxxx.h5")
callbacks = [
tf.keras.callbacks.ModelCheckpoint(output_model_file, save_best_file = True)
]
hist = model.fit_generator(xxxxx, callbacks = callbacks)
(2): 使用model.save(),會把整個模型保存下來,包括網絡和參數
(3): 使用model.save_weights(),只保存模型的參數
當使用自定義的層或loss時,只有(3)可以直接使用,1 2會報下面這種錯:
NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
ValueError: Unknown loss function:loss
ValueError: Unknown layer: xxxlayer
解決辦法:
在自定義網絡層時重寫get_config函數
我們主要看傳入__init__接口時有哪些配置參數,然后在get_config內一一的將它們轉為字典鍵值並且返回使用,以Mylayer為例:
class MyLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs, name="MyLayer", **kwargs):
super(MyLayer, self).__init__(name=name, **kwargs)
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
super().build(input_shape)
def call(self, input):
output = tf.matmul(input, self.kernel)
return output
def get_config(self):
config = {"num_outputs":self.num_outputs}
base_config = super(Mylayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
一般來說,父類的config也是需要一並保存的,其中base_config即是父類網絡層實現的配置參數,最后把父類及繼承類的config組裝為字典形式即可解決該問題
然后 在加載模型的時候,建立一個字典,該字典的鍵是自定義網絡層時設定該層的名字,其值為該自定義網絡層的類名,該字典將用於加載模型時使用
如果還使用了自定義的loss,則把loss也加到_custom_objects中
_custom_objects = {
"Mylayer" : Mylayer,
"loss" : Myloss
}
最后在load模型的時候把_custom_objects傳入
model = tf.keras.models.load_model("path/to/your/model", custom_objects=_custom_objects)