參考文獻:Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. Reilly Media, 2019.
除了使用函數API外,還可以通過子類(subclass)自定義神經網絡模型。
假設要搭建如圖所示的神經網格,使用函數API:
input_A = keras.layers.Input(shape=[5], name="wide_input")
input_B = keras.layers.Input(shape=[6], name="deep_input")
hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
output = keras.layers.Dense(1, name="main_output")(concat)
aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
model = keras.models.Model(inputs=[input_A, input_B],
outputs=[output, aux_output])
換成子類API,
class WideAndDeepModel(keras.models.Model):
def __init__(self, units=30, activation="relu", **kwargs):
super().__init__(**kwargs)
self.hidden1 = keras.layers.Dense(units, activation=activation)
self.hidden2 = keras.layers.Dense(units, activation=activation)
self.main_output = keras.layers.Dense(1)
self.aux_output = keras.layers.Dense(1)
def call(self, inputs):
input_A, input_B = inputs
hidden1 = self.hidden1(input_B)
hidden2 = self.hidden2(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
main_output = self.main_output(concat)
aux_output = self.aux_output(hidden2)
return main_output, aux_output
初始化模型並編譯
model = WideAndDeepModel(30, activation="relu")
model.compile(loss="mse", loss_weights=[0.9, 0.1], optimizer=keras.optimizers.SGD(lr=1e-3))
和函數式API不同,使用子類搭建的神經網絡,如果運行model.summary,系統會報錯
ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.
這是因為通過子類搭建的網絡中不存在graph,即沒有網絡層之間的連接信息,因此無法使用model.summary()
。如果想要使用model.summary()
,有兩種方法:
第一種方法比較別扭,就是先讀入數據訓練一次,
history = model.fit((X_train_A, X_train_B), (y_train, y_train), epochs=2,
validation_data=((X_valid_A, X_valid_B), (y_valid, y_valid)))
再運行model.summary
就可以輸出模型信息
Model: "wide_and_deep_model_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_28 (Dense) multiple 210
_________________________________________________________________
dense_29 (Dense) multiple 930
_________________________________________________________________
dense_30 (Dense) multiple 36
_________________________________________________________________
dense_31 (Dense) multiple 31
=================================================================
Total params: 1,207
Trainable params: 1,207
Non-trainable params: 0
_________________________________________________________________
不同於通過子類API搭建的模型,使用model.summary()無法輸出神經網絡的詳細信息,這是子類API的缺點。
第二種方法其實報錯信息里提示,就是需要先運行一次模型build
,輸入神經網絡的input shape。需要注意的是,這是一個Multi-Inputs神經網格,因此input shape是一個列表
model.build([(None, 5),(None, 6)])
之后再運行一次model.summary()
就不會報錯。