tensorflow中batch normalization的用法


網上找了下tensorflow中使用batch normalization的博客,發現寫的都不是很好,在此總結下:

1.原理

公式如下:

y=γ(x-μ)/σ+β

其中x是輸入,y是輸出,μ是均值,σ是方差,γ和β是縮放(scale)、偏移(offset)系數。

一般來講,這些參數都是基於channel來做的,比如輸入x是一個16*32*32*128(NWHC格式)的feature map,那么上述參數都是128維的向量。其中γ和β是可有可無的,有的話,就是一個可以學習的參數(參與前向后向),沒有的話,就簡化成y=(x-μ)/σ。而μ和σ,在訓練的時候,使用的是batch內的統計值,測試/預測的時候,采用的是訓練時計算出的滑動平均值。

 

2.tensorflow中使用

tensorflow中batch normalization的實現主要有下面三個:

tf.nn.batch_normalization

tf.layers.batch_normalization

tf.contrib.layers.batch_norm

封裝程度逐個遞進,建議使用tf.layers.batch_normalization或tf.contrib.layers.batch_norm,因為在tensorflow官網的解釋比較詳細。我平時多使用tf.layers.batch_normalization,因此下面的步驟都是基於這個。

 

3.訓練

訓練的時候需要注意兩點,(1)輸入參數training=True,(2)計算loss時,要添加以下代碼(即添加update_ops到最后的train_op中)。這樣才能計算μ和σ的滑動平均(測試時會用到)

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss)

 

4.測試

測試時需要注意一點,輸入參數training=False,其他就沒了

 

5.預測

預測時比較特別,因為這一步一般都是從checkpoint文件中讀取模型參數,然后做預測。一般來說,保存checkpoint的時候,不會把所有模型參數都保存下來,因為一些無關數據會增大模型的尺寸,常見的方法是只保存那些訓練時更新的參數(可訓練參數),如下:

var_list = tf.trainable_variables() saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

 

但使用了batch_normalization,γ和β是可訓練參數沒錯,μ和σ不是,它們僅僅是通過滑動平均計算出的,如果按照上面的方法保存模型,在讀取模型預測時,會報錯找不到μ和σ。更詭異的是,利用tf.moving_average_variables()也沒法獲取bn層中的μ和σ(也可能是我用法不對),不過好在所有的參數都在tf.global_variables()中,因此可以這么寫:

var_list = tf.trainable_variables() g_list = tf.global_variables() bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] var_list += bn_moving_vars saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

按照上述寫法,即可把μ和σ保存下來,讀取模型預測時也不會報錯,當然輸入參數training=False還是要的。

注意上面有個不嚴謹的地方,因為我的網絡結構中只有bn層包含moving_mean和moving_variance,因此只根據這兩個字符串做了過濾,如果你的網絡結構中其他層也有這兩個參數,但你不需要保存,建議使用諸如bn/moving_mean的字符串進行過濾。

 

2018.4.22更新

提供一個基於mnist的示例,供大家參考。包含兩個文件,分別用於train/test。注意bn_train.py文件的51-61行,僅保存了網絡中的可訓練變量和bn層利用統計得到的mean和var。注意示例中需要下載mnist數據集,要保持電腦可以聯網。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM