Pytorch 中 model.eval() 和 with torch.no_grad() 的區別


model.eval()和with torch.no_grad()的區別
在PyTorch中進行validation時,會使用model.eval()切換到測試模式,在該模式下,

主要用於通知dropout層和batchnorm層在train和val模式間切換
在train模式下,dropout網絡層會按照設定的參數p設置保留激活單元的概率(保留概率=p); batchnorm層會繼續計算數據的mean和var等參數並更新。
在val模式下,dropout層會讓所有的激活單元都通過,而batchnorm層會停止計算和更新mean和var,直接使用在訓練階段已經學出的mean和var值。
該模式不會影響各層的gradient計算行為,即gradient計算和存儲與training模式一樣,只是不進行反傳(backprobagation)
而with torch.no_grad()則主要是用於停止autograd模塊的工作,以起到加速和節省顯存的作用,具體行為就是停止gradient計算,從而節省了GPU算力和顯存,但是並不會影響dropout和batchnorm層的行為。
使用場景
如果不在意顯存大小和計算時間的話,僅僅使用model.eval()已足夠得到正確的validation的結果;而with torch.zero_grad()則是更進一步加速和節省gpu空間(因為不用計算和存儲gradient),從而可以更快計算,也可以跑更大的batch來測試。

參考
https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615/38
https://ryankresse.com/batchnorm-dropout-and-eval-in-pytorch/
————————————————
版權聲明:本文為CSDN博主「江前雲后」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/songyu0120/article/details/103884586


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM