Pytorch中的BatchNorm的API主要有:
1 torch.nn.BatchNorm1d(num_features, 2 3 eps=1e-05, 4 5 momentum=0.1, 6 7 affine=True, 8 9 track_running_stats=True)
一般來說pytorch中的模型都是繼承nn.Module
類的,都有一個屬性trainning
指定是否是訓練狀態,訓練狀態與否將會影響到某些層的參數是否是固定的,比如BN層或者Dropout層。通常用model.train()
指定當前模型model
為訓練狀態,model.eval()
指定當前模型為測試狀態。
同時,BN的API中有幾個參數需要比較關心的,一個是affine
指定是否需要仿射,還有個是track_running_stats
指定是否跟蹤當前batch的統計特性。容易出現問題也正好是這三個參數:trainning
,affine
,track_running_stats
。
- 其中的
affine
指定是否需要仿射,也就是是否需要上面算式的第四個,如果affine=False
則γγ=1,β=0 \gamma=1,\beta=0γ=1,β=0,並且不能學習被更新。一般都會設置成affine=True
[10] trainning
和track_running_stats
,track_running_stats=True
表示跟蹤整個訓練過程中的batch的統計特性,得到方差和均值,而不只是僅僅依賴與當前輸入的batch的統計特性。相反的,如果track_running_stats=False
那么就只是計算當前輸入的batch的統計特性中的均值和方差了。當在推理階段的時候,如果track_running_stats=False
,此時如果batch_size
比較小,那么其統計特性就會和全局統計特性有着較大偏差,可能導致糟糕的效果。
一般來說,trainning
和track_running_stats
有四種組合[7]
trainning=True
,track_running_stats=True
。這個是期望中的訓練階段的設置,此時BN將會跟蹤整個訓練過程中batch的統計特性。trainning=True
,track_running_stats=False
。此時BN只會計算當前輸入的訓練batch的統計特性,可能沒法很好地描述全局的數據統計特性。trainning=False
,track_running_stats=True
。這個是期望中的測試階段的設置,此時BN會用之前訓練好的模型中的(假設已經保存下了)running_mean
和running_var
並且不會對其進行更新。一般來說,只需要設置model.eval()
其中model
中含有BN層,即可實現這個功能。[6,8]trainning=False
,track_running_stats=False
效果同(2),只不過是位於測試狀態,這個一般不采用,這個只是用測試輸入的batch的統計特性,容易造成統計特性的偏移,導致糟糕效果。
同時,我們要注意到,BN層中的running_mean
和running_var
的更新是在forward()
操作中進行的,而不是optimizer.step()
中進行的,因此如果處於訓練狀態,就算你不進行手動step()
,BN的統計特性也會變化的。如
1 model.train() # 處於訓練狀態 2 3 4 for data, label in self.dataloader: 5 6 pred = model(data) 7 8 # 在這里就會更新model中的BN的統計特性參數,running_mean, running_var 9 10 loss = self.loss(pred, label) 11 12 # 就算不要下列三行代碼,BN的統計特性參數也會變化 13 14 opt.zero_grad() 15 16 loss.backward() 17 18 opt.step()
這個時候要將model.eval()
轉到測試階段,才能固定住running_mean
和running_var
。有時候如果是先預訓練模型然后加載模型,重新跑測試的時候結果不同,有一點性能上的損失,這個時候十有八九是trainning
和track_running_stats
設置的不對,這里需要多注意。 [8]
Reference
[1]. 用pytorch踩過的坑
[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.
[3]. <深度學習優化策略-1>Batch Normalization(BN)
[4]. 詳解深度學習中的Normalization,BN/LN/WN
[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24
[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870
[7]. BatchNorm2d增加的參數track_running_stats如何理解?
[8]. Why track_running_stats is not set to False during eval
[9]. How to train with frozen BatchNorm?
[10]. Proper way of fixing batchnorm layers during training
[11]. 大白話《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》