pytorch 查看中間變量的梯度


pytorch 為了節省顯存,在反向傳播的過程中只針對計算圖中的葉子結點(leaf variable)保留了梯度值(gradient)。但對於開發者來說,有時我們希望探測某些中間變量(intermediate variable) 的梯度來驗證我們的實現是否有誤,這個過程就需要用到 tensor的register_hook接口。一段簡單的示例代碼如下,代碼主要來自pytorch開發者的回答,筆者稍作修改使其更符合最新版的pytorch 語法(v1.2.0)。

grads = {}

def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y

# 為中間變量注冊梯度保存接口,存儲梯度時名字為 y。
y.register_hook(save_grad('y'))

# 反向傳播 
z.backward()

# 查看 y 的梯度值
print(grads['y'])

一個示例輸出是:

tensor([-1.5435])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM