以前使用Caffe的時候沒注意這個,現在使用預訓練模型來動手做時遇到了。在slim中的自帶模型中inception, resnet, mobilenet等都自帶BN層,這個坑在《實戰Google深度學習框架》第二版這本書P166里只是提了一句,沒有做出解答。
書中說訓練時和測試時使用的參數is_training都為True,然后給出了一個鏈接供參考。本人剛開始使用時也是按照書中的做法沒有改動,后來從保存后的checkpoint中加載模型做預測時出了問題:當改變需要預測數據的batchsize時預測的label也跟着變,這意味着checkpoint里面沒有保存訓練中BN層的參數,使用的BN層參數還是從需要預測的數據中計算而來的。這顯然會出問題,當預測的batchsize越大,假如你的預測數據集和訓練數據集的分布一致,結果就越接近於訓練結果,但如果batchsize=1,那BN層就發揮不了作用,結果很難看。
那如果在預測時is_traning=false呢,但BN層的參數沒有從訓練中保存,那使用的就是隨機初始化的參數,結果不堪想象。
所以需要在訓練時把BN層的參數保存下來,然后在預測時加載,參考幾位大佬的博客,有了以下訓練時添加的代碼:
1 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 2 with tf.control_dependencies(update_ops): 3 train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss) 4 5 # 設置保存模型 6 var_list = tf.trainable_variables() 7 g_list = tf.global_variables() 8 bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] 9 bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] 10 var_list += bn_moving_vars 11 saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
這樣就可以在預測時從checkpoint文件加載BN層的參數並設置is_training=False。
最后要說的是,雖然這么做可以解決這個問題,但也可以利用預測數據來計算BN層的參數,不是說一定要保存訓練時的參數,兩種方案可以作為超參數來調節使用,看哪種方法的結果更好。
感謝幾位大佬的博客解惑:
https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0
http://www.cnblogs.com/hrlnw/p/7227447.html