pytorch autograd backward函數中 retain_graph參數的作用,簡單例子分析,以及create_graph參數的作用


retain_graph參數的作用

官方定義:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果設置為False,計算圖中的中間變量在計算完后就會被釋放。但是在平時的使用中這個參數默認都為False從而提高效率,和creat_graph的值一樣。

具體看一個例子理解:

假設一個我們有一個輸入x,y = x **2, z = y*4,然后我們有兩個輸出,一個output_1 = z.mean(),另一個output_2 = z.sum()。然后我們對兩個output執行backward。

 1 import torch
 2 x = torch.randn((1,4),dtype=torch.float32,requires_grad=True) 3 y = x ** 2 4 z = y * 4 5 print(x) 6 print(y) 7 print(z) 8 loss1 = z.mean() 9 loss2 = z.sum() 10 print(loss1,loss2) 11 loss1.backward() # 這個代碼執行正常,但是執行完中間變量都free了,所以下一個出現了問題 12 print(loss1,loss2) 13 loss2.backward() # 這時會引發錯誤

程序正常執行到第12行,所有的變量正常保存。但是在第13行報錯:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

分析:計算節點數值保存了,但是計算圖x-y-z結構被釋放了,而計算loss2的backward仍然試圖利用x-y-z的結構,因此會報錯。

因此需要retain_graph參數為True去保留中間參數從而兩個loss的backward()不會相互影響。正確的代碼應當把第11行以及之后改成

1 # 假如你需要執行兩次backward,先執行第一個的backward,再執行第二個backward
2 loss1.backward(retain_graph=True)# 這里參數表明保留backward后的中間參數。
3 loss2.backward() # 執行完這個后,所有中間變量都會被釋放,以便下一次的循環
4  #如果是在訓練網絡optimizer.step() # 更新參數

create_graph參數比較簡單,參考官方定義:
  • create_graph (booloptional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.

附參考學習的鏈接如下,並對作者表示感謝:retain_graph參數的作用.

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM