tensor.detach() 和 tensor.data 的區別


 detach()和data生成的都是無梯度的純tensor,並且通過同一個tensor數據操作,是共享一塊數據內存。

1 import torch
2 t1 = torch.tensor([0,1.],requires_grad=True)
3 t2=t1.detach()
4 t3=t1.data
5 print(t2.requires_grad,t3.requires_grad)
6 ---------------------------------------------
7 output: False, False

 x.data和x.detach()新分離出來的tensor的requires_grad=False,即不可求導時兩者之間沒有區別,但是當當requires_grad=True的時候的兩者之間的是有不同:x.data不能被autograd追蹤求微分,但是x.detach可以被autograd()追蹤求導。

 1、x.data

 1 import torch
 2 a = torch.tensor([1,2,3.], requires_grad=True)
 3 out = a.sigmoid()
 4 out 
 5 ----------------------------------------------------
 6 output: tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
 7 
 8 c = out.data
 9 c
10 -----------------------------------------------------
11 output: tensor([0.7311, 0.8808, 0.9526])
12 
13 c.zero_()    # 歸0化
14 out
15 ------------------------------------------------------
16 tensor([0., 0., 0.], grad_fn=<SigmoidBackward>)
17 
18 out.sum().backward()
19 a.grad
20 -------------------------------------------------------
21 output:tensor([0., 0., 0.])

 2、x.detach()

 1 b = torch.tensor([1,2,3.], requires_grad=True)
 2 out1 = b.sigmoid()
 3 out1
 4 ------------------------------------------------------
 5 output:tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
 6 
 7 c1 = out1.detach()
 8 c1
 9 ------------------------------------------------------
10 output:tensor([0.7311, 0.8808, 0.9526])
11 
12 c1.zero_()
13 out1.sum().backward()   # 報錯是是因為autograd追蹤求導的時候發現數據已經發生改變,被覆蓋。
14 -------------------------------------------------------
15 output: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:

 總結:

  x.data和x.detach()都是從原有計算中分離出來的一個tensor變量 ,並且都是inplace operation.在進行autograd追蹤求倒時,兩個的常量是相同。

  不同:.data時屬性,detach()是方法。 x.data不是安全的,x.detach()是安全的。


參考1:tensor中的data()函數與detach()的區別


 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM