Keras模型拼裝


在訓練較大網絡時, 往往想加載預訓練的模型, 但若想在網絡結構上做些添補, 可能出現問題一二...

一下是添補的幾種情形, 此處以單輸出回歸任務為例:

# 添在末尾:
base_model = InceptionV3(weights='imagenet', include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1, activation='relu')(x)

model = Model(inputs=base_model.input, outputs=x)
model.summary()
# 添在開頭和末尾:
# 在開頭加1x1卷積層, 使4通道降為3通道, 再傳入InceptionV3
def head_model(input_shape=(150, 150, 4)):
    input_tensor = Input(input_shape)
    x = Conv2D(128, (1, 1), activation='relu')(input_tensor)
    x = Conv2D(3, (1, 1), activation='relu')(x)
    model = Model(inputs=input_tensor, outputs=x, name='head')
    return model

head_model = head_model()
body_model = InceptionV3(weights='imagenet', include_top=False)
base_model = Model(head_model.input, body_model(head_model.output))
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1, activation='relu')(x)

model = Model(inputs=base_model.inputs, outputs=x, name='net')
base_model.summary()
# 兩數據輸入流合並於末尾:
base_model = InceptionV3(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
flat = Flatten()(base_model.output)
input_K = Input((100, ))    # another_input
K_flow = Activation(activation='linear')(input_K)
x = concatenate([flat, K_flow])    # 合流
x = Dense(1024, activation='relu')(x)
x = Dense(512, activation='relu')(x)
x = Dense(1, activation='relu')(x)
model = Model(inputs=[*base_model.inputs, input_K], outputs=x)    # 數據生成器那里也以這種形式生成([x_0, x_1], y)即可.
model.summary()

References:
末尾
開頭
末尾合流_0 末尾合流_1

附相關問題:

#開頭
在名為convXd_Y的shape得到的是(a, b, c, d), 但本應該為(z, y, x, w) -- 在確保沒有模型拼接時的低級錯誤時, 可嘗試將在pre-trained的模型前的那幾層, 如Conv2D層, 賦以如name='head_conv_0'等 與 框架默認賦予的形如convXd_Y 不會沖突的名字, 不然按默認的來, pre-trained的模型中的第一個卷積層屬性會被賦予你開頭新添加的第一個卷積層中, 故生上錯. 但其實, 你也可以手動先從pre-trained的模型中get_weights(), 繼而逐層往新搭建的模型里set_weights(), 詳見Keras相關文檔.

#末尾合流
ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 512)_1
ValueError: The shape of the input to "Flatten" is not fully defined (got (None, None, 512)_2


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM