PyTorch grad_fn的作用以及RepeatBackward, SliceBackward示例


變量.grad_fn表明該變量是怎么來的,用於指導反向傳播。例如loss = a+b,則loss.gard_fn為<AddBackward0 at 0x7f2c90393748>,表明loss是由相加得來的,這個grad_fn可指導怎么求a和b的導數

程序示例:

import torch

w1 = torch.tensor(2.0, requires_grad=True)
a = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
tmp = a[0, :]
tmp.retain_grad()   # tmp是非葉子張量,需用.retain_grad()方法保留導數,否則導數將會在反向傳播完成之后被釋放掉
b = tmp.repeat([3, 1])
b.retain_grad()
loss = (b * w1).mean()
loss.backward()

print(b.grad_fn)    # 輸出: <RepeatBackward object at 0x7f2c903a10f0>
print(b.grad)       # 輸出: tensor([[0.3333, 0.3333],
                    #               [0.3333, 0.3333],
                    #               [0.3333, 0.3333]])

print(tmp.grad_fn)    # 輸出:<SliceBackward object at 0x7f2c90393f60>
print(tmp.grad)       # 輸出:tensor([1., 1.])


print(a.grad)     # 輸出:tensor([[1., 1.],
                  #              [0., 0.]])

手動推導:

手動推導的結果和程序的結果是一致的。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM