1.多輸入、多輸出
模型某一層接收多輸入數據,以實現共享該層參數的目的。如對title和desc做文本分類,兩類可以共享一個embedding數據,進而獲取某種關聯特征,示例代碼如下:
title = Input(shape=(30,),name="title") desc = Input(shape=(200,),name="desc") # title和desc 共享 mebedding layer embedding = Embedding(3000, 128) title_embedd = embedding(title) desc_embedd = embedding(desc) title_lstm = LSTM(128)(title_embedd) desc_lstm = LSTM(128)(desc_embedd) out_title = Dense(1,activation="sigmoid",name="out_title")(title_lstm) out_desc = Dense(1,activation="sigmoid",name="out_desc")(desc_lstm) model = Model(inputs=[title,desc],outputs=[out_title,out_desc]) keras.utils.plot_model(model, show_shapes=True)
打印model:
2.不同輸出設置不同的類型loss和weights
# model compile model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"]) # 輸入和輸出有多個,喂數據時整理成list形式,對應好 model.fit([title_input, desc_input],[title_out, desc_out]) # 不同的輸出設置不同的loss和權重 model.compile(loss={"out_title":"binary_crossentropy", "out_desc":"categorical_crossentropy"}, optimizer="adam", metrics=["accuracy"]) model.compile(loss={"out_title":"binary_crossentropy", "out_desc":"categorical_crossentropy"}, loss_weights={"out_title":0.3,"out_desc":0.8},optimizer="adam",metrics=["accuracy"])
注:根據輸出的名稱對應設置類型,Keras這種思路無處不在
Keras API:https://keras.io/api/models/model_training_apis/