pytorch:對比clone、detach以及copy_等張量復制操作


文章轉載於:https://blog.csdn.net/guofei_fly/article/details/104486708

pytorch提供了clonedetachcopy_new_tensor等多種張量的復制操作,尤其前兩者在深度學習的網絡架構中經常被使用,本文旨在對比這些操作的差別。

1. clone

返回一個和源張量同shapedtypedevice的張量,與源張量不共享數據內存,但提供梯度的回溯

下面,通過例子來詳細說明:

示例

(1)定義

import torch
a = torch.tensor(1.0, requires_grad=True, device="cuda", dtype=torch.float64)
a_ = a.clone()
print(a_)   # tensor(1., device='cuda:0', dtype=torch.float64, grad_fn=<CloneBackward>)

  
  
  
          

注意grad_fn=<CloneBackward>,說明clone后的返回值是個中間variable,因此支持梯度的回溯。因此,clone操作在一定程度上可以視為是一個identity-mapping函數。

(2)梯度的回溯

clone作為一個中間variable,會將梯度傳給源張量進行疊加。

import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2 
a_ = a.clone()
z = a_ * 3
y.backward()
print(a.grad)   # 2
z.backward()
print(a_.grad)   # None. 中間variable,無grad
print(a.grad)    # 5. a_的梯度會傳遞回給a,因此2+3=5

  
  
  
          

但若源張量的require_grad=False,而clone后的張量require_grad=True,顯然此時不存在張量回溯現象,clone后的張量可以求導。

import torch
a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_()
y = a_ ** 2
y.backward()
print(a.grad)   # None
print(a_.grad)   # 2. 可得到導數

  
  
  
          

(3)張量數據非共享

import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()
a.data *= 3
a_ += 1
print(a)   # tensor(3., requires_grad=True)
print(a_)  # tensor(2., grad_fn=<AddBackward0>). 注意grad_fn的變化

  
  
  
          

綜上論述,clone操作在不共享數據內存的同時支持梯度回溯,所以常用在神經網絡中某個單元需要重復使用的場景下。

2. detach

detach的機制則與clone完全不同,即返回一個和源張量同shapedtypedevice的張量,與源張量共享數據內存,但不提供梯度計算,即requires_grad=False,因此脫離計算圖。

同樣,通過例子來詳細說明:

(1)定義

import torch
a = torch.tensor(1.0, requires_grad=True, device="cuda", dtype=torch.float64)
a_ = a.detach()
print(a_)   # tensor(1., device='cuda:0', dtype=torch.float64)

  
  
  
          

(2)脫離原計算圖

import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2 
a_ = a.detach()
print(a_.grad)    # None,requires_grad=False
a_.requires_grad_()  # 強制其requires_grad=True,從而支持求導
z = a_ * 3
y.backward()
z.backward()
print(a.grad)    # 2,與a_無關系
print(a_.grad)   #

  
  
  
          

可見,detach后的張量,即使重新定義requires_grad=True,也與源張量的梯度沒有關系。

(3)共享張量數據內存

import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.detach()
print(a)    # tensor(1., requires_grad=True)
print(a_)   # tensor(1.)
a_ += 1   
print(a)     # tensor(2., requires_grad=True)
print(a_)    # tensor(2.)
a.data *= 2
print(a)    # tensor(4., requires_grad=True)
print(a_)    # tensor(4.)

  
  
  
          

綜上論述,detach操作在共享數據內存的脫離計算圖,所以常用在神經網絡中僅要利用張量數值,而不需要追蹤導數的場景下。

3. clone和detach聯合使用

clone提供了非數據共享的梯度追溯功能,而detach又“舍棄”了梯度功能,因此clonedetach意味着着只做簡單的數據復制,既不數據共享,也不對梯度共享,從此兩個張量無關聯。

置於是先clone還是先detach,其返回值一樣,一般采用tensor.clone().detach()

4. new_tensor

new_tensor可以將源張量中的數據復制到目標張量(數據不共享),同時提供了更細致的devicedtyperequires_grad屬性控制:

new_tensor(data, dtype=None, device=None, requires_grad=False) 

  
  
  
          

注意:其默認參數下的操作等同於.clone().detach(),而requires_grad=True時的效果相當於.clone().detach()requires_grad_(True)。上面兩種情況都推薦使用后者。

5. copy_

copy_同樣將源張量中的數據復制到目標張量(數據不共享),其devicedtyperequires_grad一般都保留目標張量的設定,僅僅進行數據復制,同時其支持broadcast操作。

a = torch.tensor([[1,2,3], [4,5,6]], device="cuda")
b = torch.tensor([7.0,8.0,9.0], requires_grad=True)
a.copy_(b)
print(a)   # tensor([[7, 8, 9], [7, 8, 9]], device='cuda:0') 

  
  
  
          

【Ref】:

  1. 關於 pytorch inplace operation, 需要知道的幾件事


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM