pytorch detach函數


用於截斷反向傳播

detach()源碼:

def detach(self):
    result = NoGrad()(self)  # this is needed, because it merges version counters
    result._grad_fn = None
    return result

它的返回結果與調用者共享一個data tensor,且會將grad_fn設為None,這樣就不知道該Tensor是由什么操作建立的,截斷反向傳播

這個時候再一個tensor使用In_place操作會導致另一個的data tensor也會發生改變

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
out = a.sigmoid()
print(out)#tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)

c = out.detach()
print(c)#tensor([0.7311, 0.8808, 0.9526])

這個時候可以看到,c和out的區別就是一個有grad_fn,一個沒有grad_fn

執行out.sum().backward()沒有問題,但執行c.sum().backward()報錯:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

這個時候不論是對out還是對c進行inplace操作改變它們的data,這個改動會被autograd追蹤,這個時候再執行out.sum().backward()會報錯

假設對out進行inplace操作,會出現:

out.zero_()
#tensor([0., 0., 0.], grad_fn=<ZeroBackward>)

out.sum().backward()
#報錯

錯誤信息為

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of SigmoidBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

如果不對out進行inplace操作而是對c進行inplace操作,結果是一樣的,Out不能再進行反向傳播了

為了解決這種情況,就要對tensor的data操作,使其不被autograd記錄
重新得到一個out,把它的data部分給c

c = out.data
#tensor([0.7311, 0.8808, 0.9526])

out
#tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)

這里可以看到,c中沒有Out中有的grad_fn信息

這回修改c的值,發現out的data值依然改了,但是執行out.sum().backward()不報錯了

 

 

detach_()

def detach_(self):
    """Detaches the Variable from the graph that created it, making it a leaf.
    """
    self._grad_fn = None
    self.requires_grad = False

做了兩件事:1grad_fn設none2requires_grad設false

它不會新生成一個Variable而是使用原來的variable

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM