pytorch requires_grad = True的意思


計算圖通常包含兩種元素,一個是 tensor,另一個是 Function。張量 tensor 不必多說,但是大家可能對 Function 比較陌生。這里 Function 指的是在計算圖中某個節點(node)所進行的運算,比如加減乘除卷積等等之類的,Function 內部有 forward() 和 backward() 兩個方法,分別應用於正向、反向傳播。

當我們創建一個張量 (tensor) 的時候,如果沒有特殊指定的話,那么這個張量是默認是不需要求導的。們在訓練一個網絡的時候,我們從 DataLoader 中讀取出來的一個 mini-batch 的數據,這些輸入默認是不需要求導的,其次,網絡的輸出我們沒有特意指明需要求導吧,Ground Truth 我們也沒有特意設置需要求導吧。這么一想,哇,那我之前的那些 loss 咋還能自動求導呢?其實原因就是上邊那條規則,雖然輸入的訓練數據是默認不求導的,但是,我們的 model 中的所有參數,它默認是求導的,這么一來,其中只要有一個需要求導,那么輸出的網絡結果必定也會需要求的。來看個實例:

input = torch.randn(8, 3, 50, 100)
print(input.requires_grad)
# False

net = nn.Sequential(nn.Conv2d(3, 16, 3, 1),
                    nn.Conv2d(16, 32, 3, 1))
for param in net.named_parameters():
    print(param[0], param[1].requires_grad)
# 0.weight True
# 0.bias True
# 1.weight True
# 1.bias True

output = net(input)
print(output.requires_grad)
# True

在寫代碼的過程中,不要把網絡的輸入和 Ground Truth 的 requires_grad 設置為 True。雖然這樣設置不會影響反向傳播,但是需要額外計算網絡的輸入和 Ground Truth 的導數,增大了計算量和內存占用不說,這些計算出來的導數結果也沒啥用。因為我們只需要神經網絡中的參數的導數,用來更新網絡,其余的導數都不需要。

 

原文鏈接:https://zhuanlan.zhihu.com/p/67184419


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2026 CODEPRJ.COM