keras中保存自定義層和loss


在keras中保存模型有幾種方式:

(1):使用callbacks,可以保存訓練中任意的模型,或選擇最好的模型

logdir = './callbacks'
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir, "xxxx.h5")
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(output_model_file, save_best_file = True)
]

hist = model.fit_generator(xxxxx, callbacks = callbacks)

(2): 使用model.save(),會把整個模型保存下來,包括網絡和參數

(3): 使用model.save_weights(),只保存模型的參數

當使用自定義的層或loss時,只有(3)可以直接使用,1 2會報下面這種錯:

NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
ValueError: Unknown loss function:loss
ValueError: Unknown layer: xxxlayer

解決辦法:

在自定義網絡層時重寫get_config函數

我們主要看傳入__init__接口時有哪些配置參數,然后在get_config內一一的將它們轉為字典鍵值並且返回使用,以Mylayer為例:

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs, name="MyLayer", **kwargs):
        super(MyLayer, self).__init__(name=name, **kwargs)
        self.num_outputs = num_outputs

    def build(self, input_shape):
        self.kernel = self.add_variable("kernel", shape=[int(input_shape[-1]), self.num_outputs])
        super().build(input_shape)

    def call(self, input):
        output = tf.matmul(input, self.kernel)
        return output

    def get_config(self):
       config = {"num_outputs":self.num_outputs}
       base_config = super(Mylayer, self).get_config()
       return dict(list(base_config.items()) + list(config.items()))

一般來說,父類的config也是需要一並保存的,其中base_config即是父類網絡層實現的配置參數,最后把父類及繼承類的config組裝為字典形式即可解決該問題

然后 在加載模型的時候,建立一個字典,該字典的鍵是自定義網絡層時設定該層的名字,其值為該自定義網絡層的類名,該字典將用於加載模型時使用

如果還使用了自定義的loss,則把loss也加到_custom_objects中

_custom_objects = {
    "Mylayer" :  Mylayer,
   "loss" : Myloss
}

最后在load模型的時候把_custom_objects傳入

model = tf.keras.models.load_model("path/to/your/model", custom_objects=_custom_objects)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM