Pytorch中的model.train()與model.eval()
最近在跑實驗代碼, 發現對於Pytorch中的model.train()與model.eval()兩種模式的理解只是停留在理論知識的層面,缺少了實操的經驗。下面博主將從理論層面與實驗經驗這兩個方面總結model.train()與model.eval()的區別和坑點。
0. 理論區別
首先需要明確的是這兩個模式會影響Dropout和BatchNormal這兩個Module的行為。
0.1 Dropout
在train模式下,Dropout會根據概率p隨機將部分輸出變為0,即讓一些神經元失活。
在eval模式下, Dropout將不再會將部分輸出變為0,即相當於從模型中移除了該Module。
0.2 BatchNorm
首先,有關BN的理論學習請移步我另一篇博客帶你一文讀懂Batch Normalization。這里我只是討論train與eval模式下BN的行為差異。
首先需要明確BN的行為由training屬性(這里就是通過model.train()設置)和track_running_stats屬性控制。
在BN中track_running_stats屬性默認為True,在train模式下,forward的時候統計running_mean, running_var並將其作為 μ , σ \mu, \sigma μ,σ,其統計公式如下圖所示,在eval模式下,利用前面統計的均值和方差作為 μ , σ \mu, \sigma μ,σ用於inference。其中running_mean, running_var的更新規則如下,可以簡單地理解為計算了多個batch的平均值,要想了解指數平均具體原理可以一步博主另一篇博客一文帶你入門深度學習優化算法
控制邏輯如下:
下面根據train和track_running_stats的兩種狀態兩兩組合為四類,分類討論:
敲重點!!!:
- training=True, track_running_stats=True:即訓練模式下跟蹤運行狀態, 此時前向傳播使用的 σ , μ \sigma, \mu σ,μ為running_mean、running_var。
- training=True, track_running_stats=False:即訓練模式下不跟蹤運行狀態,此時前向傳播使用的 σ , μ \sigma, \mu σ,μ為None,也就是當前batch的均值和方差
- training=False, track_running_stats=True:即評估模式下跟蹤運行狀態,此時前向傳播使用的 σ , μ \sigma, \mu σ,μ為running_mean、running_var。
- training=False, track_running_stats=False:即評估模式下不跟蹤運行狀態,此時前向傳播使用的 σ , μ \sigma, \mu σ,μ為running_mean、running_var。
所以使用當前batch的統計量的情況只有一個,就是訓練模式下不跟蹤運行狀態。
1. 實驗經驗1
1.1 實驗配置
要從實驗經驗上面講兩種模式,就需要從一個具體的實驗開始。首先說一下我的實驗配置:
- 模型中全部使用默認參數創建BN,即此時momentum=0.1,track_running_stats=True,也就是大概跟蹤10個batch的平均值,沒有Dropout
- batch size為12
- 數據集分為train split和test split,訓練在train split上,測試精度在test split之上得到
- 使用了在ImageNet之上pretrain的模型,即此時BN已經有了running_mean和running_var了
- 使用了ImageNet的mean和std對數據進行規范化處理
- 數據集為CVUSA
1.2 實驗現象
最后的實驗結果發現模型在eval模式之下訓練比在train模式之下訓練的測試精度更好
1.3 原因分析
通過上述理論和實驗配置我們可以知道,實驗現象就是在說,訓練中,模型使用ImageNet中的running_mean和running_var作為 μ , σ \mu, \sigma μ,σ的效果會好於在訓練中不斷利用數據集CVUSA去更新的running_mean和running_var作為 μ , σ \mu, \sigma μ,σ的效果。
導致上述現象的原因,中心問題就是train模式下利用CVUSA更新ImageNet估計的running_mean和running_var不夠穩定,所以解決策略需要從穩定running_mean和running_var上入手。
- batch size太小,因為bs太小會讓統計更新的running_mean和running_var中的噪聲比較大。
- momentum太大,可以調小momentum以估計更多的minibatch,即讓running_mean和running_var從更多的minibatch中獲得,從而更加穩定。
1.4 實驗驗證
- 在超算中, 給pretrained的BN重新初始化,再進行試驗。
在這個實驗中效果不太好
- 重新初始化后再調小momentum,即估計更多的minibatch。
在這個實驗中效果不太好
- √ 真實原因是博主做geo-loc實驗,一次前向傳播用了地面視角和空域視角兩張圖片作為一個input pair,而兩種視角圖片的均值與方差差異很大,導致共享參數的孿生神經網絡的BN層無法很好地得到數據集方差與均值的估計值。
2. 實驗經驗2
沒有做過這個實驗,但是看到論壇中在討論有關問題,這里就把它放出來。
2.1 實驗現象
訓練模型之后,開train模式進行驗證比開eval模式進行驗證精度更高。
2.2 原因分析
將上述實驗現象翻譯過來就是,模型開train模式進行評估,即用當前batch的均值與方差去繼續更新running_mean,running_var作為 μ , σ \mu, \sigma μ,σ,這樣的效果反而比開eval模式固定統計出來的running_mean,running_var作為 μ , σ \mu, \sigma μ,σ更好。導致上述問題的原因可能有以下這些:
- batchsize較小導致的估計不穩定
- momentum較大導致的估計不穩定
- 訓練集分布與測試集分布不太一致(一般會出現在那些網絡收集不出名的小數據集上面,比較經典的數據集應該不太會是這個問題)、
遇到上述問題的時候一般要先從原因下手分析再解決,但是如果實在難以定位問題(因為雖然底層邏輯的原因一樣,但是每個不同的task為什么會導致這樣的原因,還是由具體的task決定的),所以博主也提供了以下兩種快速解決方案:
- √ 快速解決方案1:把BN層的track_running_stats屬性設置為False再重新訓練,這樣可以不追蹤整個數據集的統計量,而是直接利用當前batch的計算的 μ , σ \mu, \sigma μ,σ (不過個人不是很建議這么做, bs=1的時候不能估計當前batch的均值和方差)
- √ 快速解決方案2:直接去掉BN和dropout,不過這樣可能會掉一些點
PS:雖然快速解決方案能快速奏效,但是還是有限制條件的,所以建議還是利用博主提供的分析范式,再結合不同具體的task進行分析,從根本上找到問題,然后再解決。