pytorch BatchNorm參數詳解,計算過程


https://blog.csdn.net/weixin_39228381/article/details/107896863

 

目錄

 

說明

BatchNorm1d參數

num_features

eps

momentum

affine

track_running_stats

BatchNorm1d訓練時前向傳播

BatchNorm1d評估時前向傳播

總結


說明

網絡訓練時和網絡評估時,BatchNorm模塊的計算方式不同。如果一個網絡里包含了BatchNorm,則在訓練時需要先調用train(),使網絡里的BatchNorm模塊的training=True(默認是True),在網絡評估時,需要先調用eval()使網絡的training=False。

BatchNorm1d參數

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

num_features

輸入維度是(N, C, L)時,num_features應該取C;這里N是batch size,C是數據的channel,L是數據長度。

輸入維度是(N, L)時,num_features應該取L;這里N是batch size,L是數據長度,這時可以認為每條數據只有一個channel,省略了C

eps

對輸入數據進行歸一化時加在分母上,防止除零,詳情見下文。

momentum

更新全局均值running_mean和方差running_var時使用該值進行平滑,詳情見下文。

affine

設為True時,BatchNorm層才會學習參數\gamma\beta,否則不包含這兩個變量,變量名是weight和bias,詳情見下文。

track_running_stats

設為True時,BatchNorm層會統計全局均值running_mean和方差running_var,詳情見下文。

BatchNorm1d訓練時前向傳播

  1. 首先對輸入batch求E[x]Var[x],並用這兩個結果把batch歸一化,使其均值為0,方差為1。歸一化公式用到了eps(\epsilon),即y=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon }}。如下輸入內容,shape是(3, 4),即batch_size=3,此時num_features需要傳入4。
    1.  
      tensor = torch.FloatTensor([[1, 2, 4, 1],
    2.  
      [ 6, 3, 2, 4],
    3.  
      [ 2, 4, 6, 1]])
    此時E[x]=[3, 3, 4, 2]Var[y]_{unbiased}=[7, 1, 4, 3](無偏樣本方差)和Var[y]_{biased}=[4.6667, 0.6667, 2.6667, 2.0000](有偏樣本方差),有偏和無偏的區別在於無偏的分母是N-1,有偏的分母是N。注意在BatchNorm中,用於更新running_var時,使用無偏樣本方差即,但是在對batch進行歸一化時,使用有偏樣本方差,因此如果batch_size=1,會報錯。歸一化后的內容如下。
    1.  
      [[ -0.9258, -1.2247, 0.0000, -0.7071],
    2.  
      [ 1.3887, 0.0000, -1.2247, 1.4142],
    3.  
      [ -0.4629, 1.2247, 1.2247, -0.7071]]
  2. 如果track_running_stats==True,則使用momentum更新模塊內部的running_mean(初值是[0., 0., 0., 0.])和running_var(初值是[1., 1., 1., 1.]),更新公式是x_{new}=(1-momentum)\times x_{cur}+momentum\times x_{batch},其中x_{new}代表更新后的running_mean和running_var,x_{cur}表示更新前的running_mean和running_var,x_{batch}表示當前batch的均值和無偏樣本方差。
  3. 如果track_running_stats==False,則BatchNorm中不含有running_mean和running_var兩個變量。
  4. 如果affine==True,則對歸一化后的batch進行仿射變換,即乘以模塊內部的weight(初值是[1., 1., 1., 1.])然后加上模塊內部的bias(初值是[0., 0., 0., 0.]),這兩個變量會在反向傳播時得到更新。
  5. 如果affine==False,則BatchNorm中不含有weight和bias兩個變量,什么都都不做。

BatchNorm1d評估時前向傳播

  1. 如果track_running_stats==True,則對batch進行歸一化,公式為y=\frac{x-\hat{E}[x]}{\sqrt{\hat{Var}[x]+\epsilon }},注意這里的均值和方差是running_mean和running_var,在網絡訓練時統計出來的全局均值和無偏樣本方差。
  2. 如果track_running_stats==False,則對batch進行歸一化,公式為y=\frac{x-{E}[x]}{\sqrt{​{Var}[x]+\epsilon }},注意這里的均值和方差是batch自己的mean和var,此時BatchNorm里不含有running_mean和running_var。注意此時使用的是無偏樣本方差(和訓練時不同),因此如果batch_size=1,會使分母為0,就報錯了。
  3. 如果affine==True,則對歸一化后的batch進行放射變換,即乘以模塊內部的weight然后加上模塊內部的bias,這兩個變量都是網絡訓練時學習到的。
  4. 如果affine==False,則BatchNorm中不含有weight和bias兩個變量,什么都不做。

總結

在使用batchNorm時,通常只需要指定num_features就可以了。網絡訓練前調用train(),訓練時BatchNorm模塊會統計全局running_mean和running_var,學習weight和bias,即文獻中的\gamma\beta。網絡評估前調用eval(),評估時,對傳入的batch,使用統計的全局running_mean和running_var對batch進行歸一化,然后使用學習到的weight和bias進行仿射變換。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM