pytorch中copy() clone() detach()


Torch 為了提高速度,向量或是矩陣的賦值是指向同一內存的
如果需要開辟新的存儲地址而不是引用,可以用clone()進行深拷貝

區別

clone()

解釋說明: 返回一個原張量的副本,同時不破壞計算圖,它能夠維持反向傳播計算梯度,
並且兩個張量不共享內存.一個張量上值的改變不影響另一個張量.

copy_()

解釋說明: 比如x4.copy_(x2), 將x2的數據復制到x4,並且會
修改計算圖,使得反向傳播自動計算梯度時,計算出x4的梯度后
再繼續前向計算x2的梯度. 注意,復制完成之后,兩者的值的改變互不影響,
因為他們並不共享內存.

detach()

解釋說明: 比如x4 = x2.detach(),返回一個和原張量x2共享內存的新張量x4,
兩者的改動可以相互可見, 一個張量上的值的改動會影響到另一個張量.
返回的新張量和計算圖相互獨立,即新張量和計算圖不再關聯,
因此也無法進行反向傳播計算梯度.即從計算圖上把這個張量x2拆
卸detach下來,非常形象.

detach_()

解釋說明: detach()的原地操作版本,功能和detach()類似.
比如x4 = x2.detach_(),其實x2和x4是同一個對象,返回的是self,
x2和x4具有相同的id()值.

copy.copy

參考博客

例子

a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach()
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
"""

detach()操作后的tensor與原始tensor共享數據內存,當原始tensor在計算圖中數值發生反向傳播等更新之后,detach()的tensor值也發生了改變

a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.clone()
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]], grad_fn=<CloneBackward>)
"""

grad_fn=<CloneBackward>表示clone后的返回值是個中間變量,因此支持梯度的回溯。

a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach().clone()
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]])
"""
a = torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(a)
b=a.detach().clone().requires_grad_(True)
print(b)
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
"""

clone()操作后的tensor requires_grad=True
detach()操作后的tensor requires_grad=False

import torch
torch.manual_seed(0)

x= torch.tensor([1., 2.], requires_grad=True)
clone_x = x.clone() 
detach_x = x.detach()
clone_detach_x = x.clone().detach() 

f = torch.nn.Linear(2, 1)
y = f(x)
y.backward()

print(x.grad)
print(clone_x.requires_grad)
print(clone_x.grad)
print(detach_x.requires_grad)
print(clone_detach_x.requires_grad)
'''
輸出結果如下:
tensor([-0.0053,  0.3793])
True
None
False
False
'''


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM