Keras通過子類(subclass)自定義神經網絡模型


參考文獻:Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. Reilly Media, 2019.

除了使用函數API外,還可以通過子類(subclass)自定義神經網絡模型。

假設要搭建如圖所示的神經網格,使用函數API:

input_A = keras.layers.Input(shape=[5], name="wide_input")
input_B = keras.layers.Input(shape=[6], name="deep_input")
hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
output = keras.layers.Dense(1, name="main_output")(concat)
aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
model = keras.models.Model(inputs=[input_A, input_B],
                           outputs=[output, aux_output])

換成子類API,

class WideAndDeepModel(keras.models.Model):
    def __init__(self, units=30, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.hidden1 = keras.layers.Dense(units, activation=activation)
        self.hidden2 = keras.layers.Dense(units, activation=activation)
        self.main_output = keras.layers.Dense(1)
        self.aux_output = keras.layers.Dense(1)
        
    def call(self, inputs):
        input_A, input_B = inputs
        hidden1 = self.hidden1(input_B)
        hidden2 = self.hidden2(hidden1)
        concat = keras.layers.concatenate([input_A, hidden2])
        main_output = self.main_output(concat)
        aux_output = self.aux_output(hidden2)
        return main_output, aux_output

初始化模型並編譯

model = WideAndDeepModel(30, activation="relu")
model.compile(loss="mse", loss_weights=[0.9, 0.1], optimizer=keras.optimizers.SGD(lr=1e-3))

和函數式API不同,使用子類搭建的神經網絡,如果運行model.summary,系統會報錯

ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.

這是因為通過子類搭建的網絡中不存在graph,即沒有網絡層之間的連接信息,因此無法使用model.summary() 。如果想要使用model.summary(),有兩種方法:
第一種方法比較別扭,就是先讀入數據訓練一次,

history = model.fit((X_train_A, X_train_B), (y_train, y_train), epochs=2,
                    validation_data=((X_valid_A, X_valid_B), (y_valid, y_valid)))

再運行model.summary就可以輸出模型信息

Model: "wide_and_deep_model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_28 (Dense)             multiple                  210       
_________________________________________________________________
dense_29 (Dense)             multiple                  930       
_________________________________________________________________
dense_30 (Dense)             multiple                  36        
_________________________________________________________________
dense_31 (Dense)             multiple                  31        
=================================================================
Total params: 1,207
Trainable params: 1,207
Non-trainable params: 0
_________________________________________________________________

不同於通過子類API搭建的模型,使用model.summary()無法輸出神經網絡的詳細信息,這是子類API的缺點。
第二種方法其實報錯信息里提示,就是需要先運行一次模型build,輸入神經網絡的input shape。需要注意的是,這是一個Multi-Inputs神經網格,因此input shape是一個列表

model.build([(None, 5),(None, 6)])

之后再運行一次model.summary()就不會報錯。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM