Pytorch中的BatchNorm


轉自:https://blog.csdn.net/LoseInVain/article/details/86476010

前言
本文主要介紹在pytorch中的Batch Normalization的使用以及在其中容易出現的各種小問題,本來此文應該歸屬於[1]中的,但是考慮到此文的篇幅可能會比較大,因此獨立成篇,希望能夠幫助到各位讀者。如有謬誤,請聯系指出,如需轉載,請注明出處,謝謝。
\nabla 聯系方式:
e-mail: FesianXu@163.com
QQ: 973926198
github: https://github.com/FesianXu


Batch Normalization,批規范化

Batch Normalization(簡稱為BN)[2],中文翻譯成批規范化,是在深度學習中普遍使用的一種技術,通常用於解決多層神經網絡中間層的協方差偏移(Internal Covariate Shift)問題,類似於網絡輸入進行零均值化和方差歸一化的操作,不過是在中間層的輸入中操作而已,具體原理不累述了,見[2-4]的描述即可。

在BN操作中,最重要的無非是這四個式子:
Unexpected text node: ' '
注意到這里的最后一步也稱之為仿射(affine),引入這一步的目的主要是設計一個通道,使得輸出output至少能夠回到輸入input的狀態(當 γ = 1 , β = 0 \gamma=1,\beta=0 時)使得BN的引入至少不至於降低模型的表現,這是深度網絡設計的一個套路。
整個過程見流程圖,BN在輸入后插入,BN的輸出作為規范后的結果輸入的后層網絡中。

forward
backward
forward
backward
input batch
Batch_Norm
Output batch

好了,這里我們記住了,在BN中,一共有這四個參數我們要考慮的:

  • γ , β \gamma, \beta :分別是仿射中的 w e i g h t \mathrm{weight} b i a s \mathrm{bias} ,在pytorch中用weightbias表示。
  • μ B \mu_{\mathcal{B}} σ B 2 \sigma_{\mathcal{B}}^2 :和上面的參數不同,這兩個是根據輸入的batch的統計特性計算的,嚴格來說不算是“學習”到的參數,不過對於整個計算是很重要的。在pytorch中,用running_meanrunning_var表示[5]

在Pytorch中使用

Pytorch中的BatchNorm的API主要有:

torch.nn.BatchNorm1d(num_features, 
                     eps=1e-05, 
                     momentum=0.1, 
                     affine=True, 
                     track_running_stats=True)

   
   
   
           
  • 1
  • 2
  • 3
  • 4
  • 5

一般來說pytorch中的模型都是繼承nn.Module類的,都有一個屬性trainning指定是否是訓練狀態,訓練狀態與否將會影響到某些層的參數是否是固定的,比如BN層或者Dropout層。通常用model.train()指定當前模型model為訓練狀態,model.eval()指定當前模型為測試狀態。
同時,BN的API中有幾個參數需要比較關心的,一個是affine指定是否需要仿射,還有個是track_running_stats指定是否跟蹤當前batch的統計特性。容易出現問題也正好是這三個參數:trainningaffinetrack_running_stats

  • 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四個,如果affine=False γ = 1 , β = 0 \gamma=1,\beta=0 ,並且不能學習被更新。一般都會設置成affine=True[10]
  • trainningtrack_running_statstrack_running_stats=True表示跟蹤整個訓練過程中的batch的統計特性,得到方差和均值,而不只是僅僅依賴與當前輸入的batch的統計特性。相反的,如果track_running_stats=False那么就只是計算當前輸入的batch的統計特性中的均值和方差了。當在推理階段的時候,如果track_running_stats=False,此時如果batch_size比較小,那么其統計特性就會和全局統計特性有着較大偏差,可能導致糟糕的效果。

一般來說,trainningtrack_running_stats有四種組合[7]

  1. trainning=True, track_running_stats=True。這個是期望中的訓練階段的設置,此時BN將會跟蹤整個訓練過程中batch的統計特性。
  2. trainning=True, track_running_stats=False。此時BN只會計算當前輸入的訓練batch的統計特性,可能沒法很好地描述全局的數據統計特性。
  3. trainning=False, track_running_stats=True。這個是期望中的測試階段的設置,此時BN會用之前訓練好的模型中的(假設已經保存下了)running_meanrunning_var並且不會對其進行更新。一般來說,只需要設置model.eval()其中model中含有BN層,即可實現這個功能。[6,8]
  4. trainning=False, track_running_stats=False 效果同(2),只不過是位於測試狀態,這個一般不采用,這個只是用測試輸入的batch的統計特性,容易造成統計特性的偏移,導致糟糕效果。

同時,我們要注意到,BN層中的running_meanrunning_var的更新是在forward()操作中進行的,而不是optimizer.step()中進行的,因此如果處於訓練狀態,就算你不進行手動step(),BN的統計特性也會變化的。如

model.train() # 處於訓練狀態

for data, label in self.dataloader:
pred = model(data)
# 在這里就會更新model中的BN的統計特性參數,running_mean, running_var
loss = self.loss(pred, label)
# 就算不要下列三行代碼,BN的統計特性參數也會變化
opt.zero_grad()
loss.backward()
opt.step()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

這個時候要將model.eval()轉到測試階段,才能固定住running_meanrunning_var。有時候如果是先預訓練模型然后加載模型,重新跑測試的時候結果不同,有一點性能上的損失,這個時候十有八九是trainningtrack_running_stats設置的不對,這里需要多注意。 [8]

假設一個場景,如下圖所示:

input
model_A
model_B
output

此時為了收斂容易控制,先預訓練好模型model_A,並且model_A內含有若干BN層,后續需要將model_A作為一個inference推理模型和model_B聯合訓練,此時就希望model_A中的BN的統計特性值running_meanrunning_var不會亂變化,因此就必須將model_A.eval()設置到測試模式,否則在trainning模式下,就算是不去更新該模型的參數,其BN都會改變的,這個將會導致和預期不同的結果。


Reference

[1]. 用pytorch踩過的坑
[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.
[3]. <深度學習優化策略-1>Batch Normalization(BN)
[4]. 詳解深度學習中的Normalization,BN/LN/WN
[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24
[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870
[7]. BatchNorm2d增加的參數track_running_stats如何理解?
[8]. Why track_running_stats is not set to False during eval
[9]. How to train with frozen BatchNorm?
[10]. Proper way of fixing batchnorm layers during training
[11]. 大白話《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》

      </div>


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM