https://blog.csdn.net/weixin_39228381/article/details/107896863
目錄
說明
網絡訓練時和網絡評估時,BatchNorm模塊的計算方式不同。如果一個網絡里包含了BatchNorm,則在訓練時需要先調用train(),使網絡里的BatchNorm模塊的training=True(默認是True),在網絡評估時,需要先調用eval()使網絡的training=False。
BatchNorm1d參數
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
num_features
輸入維度是(N, C, L)時,num_features應該取C;這里N是batch size,C是數據的channel,L是數據長度。
輸入維度是(N, L)時,num_features應該取L;這里N是batch size,L是數據長度,這時可以認為每條數據只有一個channel,省略了C
eps
對輸入數據進行歸一化時加在分母上,防止除零,詳情見下文。
momentum
更新全局均值running_mean和方差running_var時使用該值進行平滑,詳情見下文。
affine
設為True時,BatchNorm層才會學習參數
和
,否則不包含這兩個變量,變量名是weight和bias,詳情見下文。
track_running_stats
設為True時,BatchNorm層會統計全局均值running_mean和方差running_var,詳情見下文。
BatchNorm1d訓練時前向傳播
- 首先對輸入batch求
和
,並用這兩個結果把batch歸一化,使其均值為0,方差為1。歸一化公式用到了eps(
),即
。如下輸入內容,shape是(3, 4),即batch_size=3,此時num_features需要傳入4。
-
tensor = torch.FloatTensor([[1, 2, 4, 1],
-
[ 6, 3, 2, 4],
-
[ 2, 4, 6, 1]])
,
(無偏樣本方差)和
(有偏樣本方差),有偏和無偏的區別在於無偏的分母是N-1,有偏的分母是N。注意在BatchNorm中,用於更新running_var時,使用無偏樣本方差即,但是在對batch進行歸一化時,使用有偏樣本方差,因此如果batch_size=1,會報錯。歸一化后的內容如下。
-
[[ -0.9258, -1.2247, 0.0000, -0.7071],
-
[ 1.3887, 0.0000, -1.2247, 1.4142],
-
[ -0.4629, 1.2247, 1.2247, -0.7071]]
-
- 如果track_running_stats==True,則使用momentum更新模塊內部的running_mean(初值是[0., 0., 0., 0.])和running_var(初值是[1., 1., 1., 1.]),更新公式是
,其中
代表更新后的running_mean和running_var,
表示更新前的running_mean和running_var,
表示當前batch的均值和無偏樣本方差。 - 如果track_running_stats==False,則BatchNorm中不含有running_mean和running_var兩個變量。
- 如果affine==True,則對歸一化后的batch進行仿射變換,即乘以模塊內部的weight(初值是[1., 1., 1., 1.])然后加上模塊內部的bias(初值是[0., 0., 0., 0.]),這兩個變量會在反向傳播時得到更新。
- 如果affine==False,則BatchNorm中不含有weight和bias兩個變量,什么都都不做。
BatchNorm1d評估時前向傳播
- 如果track_running_stats==True,則對batch進行歸一化,公式為
,注意這里的均值和方差是running_mean和running_var,在網絡訓練時統計出來的全局均值和無偏樣本方差。 - 如果track_running_stats==False,則對batch進行歸一化,公式為
,注意這里的均值和方差是batch自己的mean和var,此時BatchNorm里不含有running_mean和running_var。注意此時使用的是無偏樣本方差(和訓練時不同),因此如果batch_size=1,會使分母為0,就報錯了。 - 如果affine==True,則對歸一化后的batch進行放射變換,即乘以模塊內部的weight然后加上模塊內部的bias,這兩個變量都是網絡訓練時學習到的。
- 如果affine==False,則BatchNorm中不含有weight和bias兩個變量,什么都不做。
總結
在使用batchNorm時,通常只需要指定num_features就可以了。網絡訓練前調用train(),訓練時BatchNorm模塊會統計全局running_mean和running_var,學習weight和bias,即文獻中的
和
。網絡評估前調用eval(),評估時,對傳入的batch,使用統計的全局running_mean和running_var對batch進行歸一化,然后使用學習到的weight和bias進行仿射變換。
