keras BatchNormalization 之坑


任務簡述:最近做一個圖像分類的任務, 一開始拿vgg跑一個baseline,輸出看起來很正常:

 

 

隨后,我嘗試其他的一些經典的模型架構,比如resnet50, xception,但訓練輸出顯示明顯異常:

 

val_loss 一直亂蹦,val_acc基本不發生變化。

檢查了輸入數據沒發現問題,因此懷疑是網絡構造有問題, 對比了vgg同xception, resnet在使用layer上的異同,認為問題可能出在BN層上,將vgg添加了BN層之后再訓練果然翻車。

 

翻看keras BN 的源碼, 原來keras 的BN層的call函數里面有個默認參數traing, 默認是None。此參數意義如下:

training=False/0, 訓練時通過每個batch的移動平均的均值、方差去做批歸一化,測試時拿整個訓練集的均值、方差做歸一化

training=True/1/None,訓練時通過當前batch的均值、方差去做批歸一化,測試時拿整個訓練集的均值、方差做歸一化

 

 當training=None時,訓練和測試的批歸一化方式不一致,導致validation的輸出指標翻車。

當training=True時,拿訓練完的模型預測一個樣本和預測一個batch的樣本的差異非常大,也就是預測的結果根據batch的大小會不同!導致模型結果無法准確評估!也是個坑!

 

用keras的BN時切記要設置training=False!!!

def build_model():
    Inputs = Input(shape=intput_shape, name='input')
    x_tmp = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(Inputs)
    x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
    x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp)
    x_tmp = BatchNormalization(x_tmp, training=False)
    x_tmp = MaxPooling2D(pool_size=(2, 2))(x_tmp)

    x_tmp = Flatten()(x_tmp)
    x_tmp = Dense(128, activation='relu')(x_tmp)
    outputs = Dense(10, activation='softmax')(x_tmp)
    model = Model(Inputs, outputs)
    return model

 

參考:

https://arxiv.org/pdf/1502.03167v3.pdf

https://github.com/keras-team/keras/blob/master/keras/layers/normalization.py#L16

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM