pytorch 中 torch.no_grad()、requires_grad、eval()


requires_grad

requires_grad=True 要求計算梯度;
requires_grad=False 不要求計算梯度;
在pytorch中,tensor有一個 requires_grad參數,如果設置為True,則反向傳播時,該tensor就會自動求導。 tensor的requires_grad的屬性默認為False,若一個節點(葉子變量:自己創建的tensor)requires_grad被設置為True,那么 所有依賴它的節點requires_grad都為True (即使其他相依賴的tensor的requires_grad = False)

x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = False)
z = torch.randn(10, 5, requires_grad = False)
w = x + y + z
w.requires_grad

輸出:

True

volatile

volatile是Variable的另一個重要的標識,它能夠將所有依賴它的節點全部設為volatile=True,優先級比requires_grad=True高。
而volatile=True的節點不會求導,即使requires_grad=True,也不會進行反向傳播,對於不需要反向傳播的情景(inference,測試階段推斷階段),該參數可以實現一定速度的提升,並節省一半的顯存,因為其不需要保存梯度。
但是, 注意 volatile已經取消了,使用with torch.no_grad()來替代

torch.no_grad()

是一個上下文管理器,被該語句 wrap 起來的部分將不會track 梯度。
with torch.no_grad()或者@torch.no_grad()中的數據不需要計算梯度,也不會進行反向傳播。
(torch.no_grad()是新版本pytorch中volatile的替代)

x = torch.randn(2, 3, requires_grad = True)
y = torch.randn(2, 3, requires_grad = False)
z = torch.randn(2, 3, requires_grad = False)
m=x+y+z
with torch.no_grad():
    w = x + y + z
print(w)
print(m)
print(w.requires_grad)
print(w.grad_fn)
print(w.requires_grad)

輸出:

tensor([[-2.7066, -0.7406,  0.5740],
        [-0.7071, -1.6057,  1.9732]])
tensor([[-2.7066, -0.7406,  0.5740],
        [-0.7071, -1.6057,  1.9732]], grad_fn=<AddBackward0>)
False
None
False

model.eval()與with torch.no_grad()

共同點:

在PyTorch中進行validation時,使用這兩者均可切換到測試模式。

如用於通知dropout層和batchnorm層在train和val模式間切換。
在train模式下,dropout網絡層會按照設定的參數p設置保留激活單元的概率(保留概率=p); batchnorm層會繼續計算數據的mean和var等參數並更新。
在val模式下,dropout層會讓所有的激活單元都通過,而batchnorm層會停止計算和更新mean和var,直接使用在訓練階段已經學出的mean和var值。

不同點:

model.eval()會影響各層的gradient計算行為,即gradient計算和存儲與training模式一樣,只是不進行反傳。

with torch.zero_grad()則停止autograd模塊的工作,也就是停止gradient計算,以起到加速和節省顯存的作用,從而節省了GPU算力和顯存,但是並不會影響dropout和batchnorm層的行為。

也就是說,如果不在意顯存大小和計算時間的話,僅使用model.eval()已足夠得到正確的validation的結果;而with torch.zero_grad()則是更進一步加速和節省gpu空間(因為不用計算和存儲gradient),從而可以更快計算,也可以跑更大的batch來測試。

參考

1.https://www.jianshu.com/p/1cea017f5d11
2.csdn博客


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM