Tensorflow訓練和預測中的BN層的坑


  以前使用Caffe的時候沒注意這個,現在使用預訓練模型來動手做時遇到了。在slim中的自帶模型中inception, resnet, mobilenet等都自帶BN層,這個坑在《實戰Google深度學習框架》第二版這本書P166里只是提了一句,沒有做出解答。

  書中說訓練時和測試時使用的參數is_training都為True,然后給出了一個鏈接供參考。本人剛開始使用時也是按照書中的做法沒有改動,后來從保存后的checkpoint中加載模型做預測時出了問題:當改變需要預測數據的batchsize時預測的label也跟着變,這意味着checkpoint里面沒有保存訓練中BN層的參數,使用的BN層參數還是從需要預測的數據中計算而來的。這顯然會出問題,當預測的batchsize越大,假如你的預測數據集和訓練數據集的分布一致,結果就越接近於訓練結果,但如果batchsize=1,那BN層就發揮不了作用,結果很難看。

  那如果在預測時is_traning=false呢,但BN層的參數沒有從訓練中保存,那使用的就是隨機初始化的參數,結果不堪想象。

  所以需要在訓練時把BN層的參數保存下來,然后在預測時加載,參考幾位大佬的博客,有了以下訓練時添加的代碼:

 1 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
 2 with tf.control_dependencies(update_ops):
 3         train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
 4 
 5 # 設置保存模型
 6 var_list = tf.trainable_variables()
 7 g_list = tf.global_variables()
 8 bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
 9 bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
10 var_list += bn_moving_vars
11 saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

這樣就可以在預測時從checkpoint文件加載BN層的參數並設置is_training=False。

最后要說的是,雖然這么做可以解決這個問題,但也可以利用預測數據來計算BN層的參數,不是說一定要保存訓練時的參數,兩種方案可以作為超參數來調節使用,看哪種方法的結果更好。

感謝幾位大佬的博客解惑:

  https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0

  http://www.cnblogs.com/hrlnw/p/7227447.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM