Keras模型多輸入-多輸出設計思路


1.多輸入、多輸出

模型某一層接收多輸入數據,以實現共享該層參數的目的。如對title和desc做文本分類,兩類可以共享一個embedding數據,進而獲取某種關聯特征,示例代碼如下:

title = Input(shape=(30,),name="title")
desc = Input(shape=(200,),name="desc")
# title和desc 共享 mebedding layer
embedding = Embedding(3000, 128)

title_embedd = embedding(title)
desc_embedd = embedding(desc)
title_lstm = LSTM(128)(title_embedd)
desc_lstm = LSTM(128)(desc_embedd)

out_title = Dense(1,activation="sigmoid",name="out_title")(title_lstm)
out_desc = Dense(1,activation="sigmoid",name="out_desc")(desc_lstm)
model = Model(inputs=[title,desc],outputs=[out_title,out_desc])
keras.utils.plot_model(model, show_shapes=True)

打印model:

 

2.不同輸出設置不同的類型loss和weights

# model compile
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
# 輸入和輸出有多個,喂數據時整理成list形式,對應好
model.fit([title_input, desc_input],[title_out, desc_out])
# 不同的輸出設置不同的loss和權重
model.compile(loss={"out_title":"binary_crossentropy", "out_desc":"categorical_crossentropy"}, optimizer="adam", metrics=["accuracy"])
model.compile(loss={"out_title":"binary_crossentropy", "out_desc":"categorical_crossentropy"}, loss_weights={"out_title":0.3,"out_desc":0.8},optimizer="adam",metrics=["accuracy"])

注:根據輸出的名稱對應設置類型,Keras這種思路無處不在

 Keras API:https://keras.io/api/models/model_training_apis/


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM