PyTorch 兩大轉置函數 transpose() 和 permute(),


pytorch中轉置用的函數就只有這兩個

  1. transpose()
  2. permute()

 

transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor

函數返回輸入矩陣input的轉置。交換維度dim0dim1

參數:

  • input (Tensor) – 輸入張量,必填
  • dim0 (int) – 轉置的第一維,默認0,可選
  • dim1 (int) – 轉置的第二維,默認1,可選

注意只能有兩個相關的交換的位置參數。

permute()

 

參數:

dims (int…*)-換位順序,必填

相同點

  1. 都是返回轉置后矩陣。
  2. 都可以操作高緯矩陣,permute在高維的功能性更強。
# 創造二維數據x,dim=0時候2,dim=1時候3
x = torch.randn(2,3)       'x.shape  →  [2,3]'
# 創造三維數據y,dim=0時候2,dim=1時候3,dim=2時候4
y = torch.randn(2,3,4)   'y.shape  →  [2,3,4]'
# 對於transpose
x.transpose(0,1)     'shape→[3,2] '  
x.transpose(1,0)     'shape→[3,2] '  
y.transpose(0,1)     'shape→[3,2,4]' 
y.transpose(0,2,1)  'error,操作不了多維'

# 對於permute()
x.permute(0,1)     'shape→[2,3]'
x.permute(1,0)     'shape→[3,2], 注意返回的shape不同於x.transpose(1,0) '
y.permute(0,1)     "error 沒有傳入所有維度數"
y.permute(1,0,2)  'shape→[3,2,4]'
 
         
合法性不同
torch.transpose(x)合法, x.transpose()合法。
tensor.permute(x)不合法,x.permute()合法。

參考第二點的舉例

操作dim不同:
transpose()只能一次操作兩個維度;permute()可以一次操作多維數據,且必須傳入所有維度數,因為permute()的參數是int*。
  1. transpose()中的dim沒有數的大小區分;permute()中的dim有數的大小區分

舉例,注意后面的shape

 

# 對於transpose,不區分dim大小
x1 = x.transpose(0,1)   'shape→[3,2] '  
x2 = x.transpose(1,0)   '也變換了,shape→[3,2] '  
print(torch.equal(x1,x2))
' True ,value和shape都一樣'

# 對於permute()
x1 = x.permute(0,1)     '不同transpose,shape→[2,3] '  
x2 = x.permute(1,0)     'shape→[3,2] '  
print(torch.equal(x1,x2))
'False,和transpose不同'

y1 = y.permute(0,1,2)     '保持不變,shape→[2,3,4] '  
y2 = y.permute(1,0,2)     'shape→[3,2,4] '  
y3 = y.permute(1,2,0)     'shape→[3,4,2] '  

 

view()函數改變通過轉置后的數據結構,導致報錯
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

這是因為tensor經過轉置后數據的內存地址不連續導致的,也就是tensor . is_contiguous()==False
雖然在torch里面,view函數相當於numpy的reshape,但是這時候reshape()可以改變該tensor結構,但是view()不可以

x = torch.rand(3,4)
x = x.transpose(0,1)
print(x.is_contiguous()) # 是否連續
'False'
# 會發現
x.view(3,4) '''
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
就是不連續導致的
'''
# 但是這樣是可以的。
x = x.contiguous()
x.view(3,4)

 

x = torch.rand(3,4)
x = x.permute(1,0) # 等價x = x.transpose(0,1)
x.reshape(3,4) '''這就不報錯了
說明x.reshape(3,4) 這個操作
等於x = x.contiguous().view()
盡管如此,但是我們還是不推薦使用reshape
除非為了獲取完全不同但是數據相同的克隆體
'''

調用contiguous()時,會強制拷貝一份tensor,讓它的布局和從頭創建的一毛一樣。

只需要記住了,每次在使用view()之前,該tensor只要使用了transpose()permute()這兩個函數一定要contiguous().

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM