『PyTorch』第五彈_深入理解autograd_中:Variable梯度探究


查看非葉節點梯度的兩種方法

在反向傳播過程中非葉子節點的導數計算完之后即被清空。若想查看這些變量的梯度,有兩種方法:

  • 使用autograd.grad函數
  • 使用hook

autograd.gradhook方法都是很強大的工具,更詳細的用法參考官方api文檔,這里舉例說明基礎的使用。推薦使用hook方法,但是在實際使用中應盡量避免修改grad的值。

求z對y的導數

x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

# hook
# hook沒有返回值,參數是函數,函數的參數是梯度值
def variable_hook(grad):
    print("hook梯度輸出:\r\n",grad)

hook_handle = y.register_hook(variable_hook)         # 注冊hook
z.backward(retain_graph=True)                        # 內置輸出上面的hook
hook_handle.remove()                                 # 釋放

print("autograd.grad輸出:\r\n",t.autograd.grad(z,y)) # t.autograd.grad方法
hook梯度輸出:
 Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

autograd.grad輸出:
 (Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]
,)

 

多次反向傳播試驗

實際就是使用retain_graph參數,

# 構件圖
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

z.backward(retain_graph=True)
print(w.grad)
z.backward()
print(w.grad)
Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

Variable containing:
 2
 2
 2
[torch.FloatTensor of size 3]

 

如果不使用retain_graph參數,

實際上效果是一樣的,AccumulateGrad object仍然會積累梯度

# 構件圖
x = V(t.ones(3))
w = V(t.rand(3),requires_grad=True)
y = w.mul(x)
z = y.sum()

z.backward()
print(w.grad)
y = w.mul(x)  # <-----
z = y.sum()  # <-----
z.backward()
print(w.grad)
Variable containing:
 1
 1
 1
[torch.FloatTensor of size 3]

Variable containing:
 2
 2
 2
[torch.FloatTensor of size 3]

 

分析:

這里的重新建立高級節點意義在這里:實際上高級節點在創建時,會緩存用於輸入的低級節點的信息(值,用於梯度計算),但是這些buffer在backward之后會被清空(推測是節省內存),而這個buffer實際也體現了上面說的動態圖的"動態"過程,之后的反向傳播需要的數據被清空,則會報錯,這樣我們上面過程就分別從:保留數據不被刪除&重建數據兩個角度實現了多次backward過程。

實際上第二次的z.backward()已經不是第一次的z所在的圖了,體現了動態圖的技術,靜態圖初始化之后會留在內存中等待feed數據,但是動態圖不會,動態圖更類似我們自己實現的機器學習框架實踐,相較於靜態邏輯簡單一點,只是PyTorch的靜態圖和我們的比會在反向傳播后清空存下的數據:下次要么完全重建,要么反向傳播之后指定不舍棄圖z.backward(retain_graph=True)。

總之圖上的節點是依賴buffer記錄來完成反向傳播,TensorFlow中會一直存留,PyTorch中就會backward后直接舍棄(默認時)。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM