任務簡述:最近做一個圖像分類的任務, 一開始拿vgg跑一個baseline,輸出看起來很正常:
隨后,我嘗試其他的一些經典的模型架構,比如resnet50, xception,但訓練輸出顯示明顯異常:
val_loss 一直亂蹦,val_acc基本不發生變化。
檢查了輸入數據沒發現問題,因此懷疑是網絡構造有問題, 對比了vgg同xception, resnet在使用layer上的異同,認為問題可能出在BN層上,將vgg添加了BN層之后再訓練果然翻車。
翻看keras BN 的源碼, 原來keras 的BN層的call函數里面有個默認參數traing, 默認是None。此參數意義如下:
training=False/0, 訓練時通過每個batch的移動平均的均值、方差去做批歸一化,測試時拿整個訓練集的均值、方差做歸一化
training=True/1/None,訓練時通過當前batch的均值、方差去做批歸一化,測試時拿整個訓練集的均值、方差做歸一化
當training=None時,訓練和測試的批歸一化方式不一致,導致validation的輸出指標翻車。
當training=True時,拿訓練完的模型預測一個樣本和預測一個batch的樣本的差異非常大,也就是預測的結果根據batch的大小會不同!導致模型結果無法准確評估!也是個坑!
用keras的BN時切記要設置training=False!!!
def build_model(): Inputs = Input(shape=intput_shape, name='input') x_tmp = Lambda(lambda c: tf.image.rgb_to_grayscale(c))(Inputs) x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp) x_tmp = Conv2D(64, (3, 3), activation='relu')(x_tmp) x_tmp = BatchNormalization(x_tmp, training=False) x_tmp = MaxPooling2D(pool_size=(2, 2))(x_tmp) x_tmp = Flatten()(x_tmp) x_tmp = Dense(128, activation='relu')(x_tmp) outputs = Dense(10, activation='softmax')(x_tmp) model = Model(Inputs, outputs) return model
參考:
https://arxiv.org/pdf/1502.03167v3.pdf
https://github.com/keras-team/keras/blob/master/keras/layers/normalization.py#L16