pytorch反向傳播兩次,梯度相加,retain_graph=True


pytorch是動態圖計算機制,也就是說,每次正向傳播時,pytorch會搭建一個計算圖,loss.backward()之后,這個計算圖的緩存會被釋放掉,下一次正向傳播時,pytorch會重新搭建一個計算圖,如此循環。

在默認情況下,PyTorch每一次搭建的計算圖只允許一次反向傳播,如果要進行兩次反向傳播,則需要在第一次反向傳播時設置retain_graph=True,即 loss.backwad(retain_graph=True) ,這樣做可以保留動態計算圖,在第二次反向傳播時,將自動和第一次的梯度相加。

示例:

import torch

input_ = torch.tensor([[1., 2.], [3., 4.]], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)

l1 = input_ * w1
l2 = l1 + w2
loss1 = l2.mean()
loss1.backward(retain_graph=True)

print(w1.grad)  # 輸出:tensor(2.5)
print(w2.grad)  # 輸出:tensor(1.)

loss2 = l2.sum()
loss2.backward()

print(w1.grad)  # 輸出:tensor(12.5)
print(w2.grad)  # 輸出:tensor(5.)

示例中的梯度推導很簡單,我在這篇博客里推了一下。從輸出結果來看,程序確實是把兩次的梯度加起來了。

附注:如果網絡要進行兩次反向傳播,卻沒有用retain_graph=True,則運行時會報錯:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM