PyTorch 兩大轉置函數 transpose() 和 permute()


PyTorch 兩大轉置函數 transpose() 和 permute(), 以及RuntimeError: invalid argument 2: view size is not compati


關心差別的可以直接看[3.不同點]和[4.連續性問題]
前言
在pytorch中轉置用的函數就只有這兩個

transpose()
permute()
注意只有transpose()有后綴格式:transpose_():后綴函數的作用是簡化如下代碼:

x = x.transpose(0,1)
等價於
x.transpose_()
# 相當於x = x + 1 簡化為 x+=1

這兩個函數都是交換維度的操作。有一些細微的區別

1. 官方文檔
transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor

函數返回輸入矩陣input的轉置。交換維度dim0和dim1

參數:

input (Tensor) – 輸入張量,必填
dim0 (int) – 轉置的第一維,默認0,可選
dim1 (int) – 轉置的第二維,默認1,可選
permute()
permute(dims) → Tensor

將tensor的維度換位。

參數:

dims (int…*)-換位順序,必填
2. 相同點
都是返回轉置后矩陣。
都可以操作高緯矩陣,permute在高維的功能性更強。
3.不同點
先定義我們后面用的數據如下

# 創造二維數據x,dim=0時候2,dim=1時候3
x = torch.randn(2,3) 'x.shape → [2,3]'
# 創造三維數據y,dim=0時候2,dim=1時候3,dim=2時候4
y = torch.randn(2,3,4) 'y.shape → [2,3,4]'

合法性不同
torch.transpose(x)合法, x.transpose()合法。
tensor.permute(x)不合法,x.permute()合法。

參考第二點的舉例

操作dim不同:
transpose()只能一次操作兩個維度;permute()可以一次操作多維數據,且必須傳入所有維度數,因為permute()的參數是int*。

舉例

# 對於transpose
x.transpose(0,1) 'shape→[3,2] '
x.transpose(1,0) 'shape→[3,2] '
y.transpose(0,1) 'shape→[3,2,4]'
y.transpose(0,2,1) 'error,操作不了多維'

# 對於permute()
x.permute(0,1) 'shape→[2,3]'
x.permute(1,0) 'shape→[3,2], 注意返回的shape不同於x.transpose(1,0) '
y.permute(0,1) "error 沒有傳入所有維度數"
y.permute(1,0,2) 'shape→[3,2,4]'

transpose()中的dim沒有數的大小區分;permute()中的dim有數的大小區分
舉例,注意后面的shape:

# 對於transpose,不區分dim大小
x1 = x.transpose(0,1) 'shape→[3,2] '
x2 = x.transpose(1,0) '也變換了,shape→[3,2] '
print(torch.equal(x1,x2))
' True ,value和shape都一樣'

# 對於permute()
x1 = x.permute(0,1) '不同transpose,shape→[2,3] '
x2 = x.permute(1,0) 'shape→[3,2] '
print(torch.equal(x1,x2))
'False,和transpose不同'

y1 = y.permute(0,1,2) '保持不變,shape→[2,3,4] '
y2 = y.permute(1,0,2) 'shape→[3,2,4] '
y3 = y.permute(1,2,0) 'shape→[3,4,2] '

4.關於連續contiguous()
經常有人用view()函數改變通過轉置后的數據結構,導致報錯
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

這是因為tensor經過轉置后數據的內存地址不連續導致的,也就是tensor . is_contiguous()==False
這時候reshape()可以改變該tensor結構,但是view()不可以,具體不同可以看view和reshape的區別
例子如下:

x = torch.rand(3,4)
x = x.transpose(0,1)
print(x.is_contiguous()) # 是否連續
'False'
# 再view會發現報錯
x.view(3,4)
'''報錯
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
'''

# 但是下面這樣是不會報錯。
x = x.contiguous()
x.view(3,4)

我們再看看reshape()

x = torch.rand(3,4)
x = x.permute(1,0) # 等價x = x.transpose(0,1)
x.reshape(3,4)
'''這就不報錯了
說明x.reshape(3,4) 這個操作
等於x = x.contiguous().view()
盡管如此,但是torch文檔中還是不推薦使用reshape
理由是除非為了獲取完全不同但是數據相同的克隆體
'''

調用contiguous()時,會強制拷貝一份tensor,讓它的布局和從頭創建的一毛一樣。
(這一段看文字你肯定不理解,你也可以不用理解,有空我會畫圖補上)

只需要記住了,每次在使用view()之前,該tensor只要使用了transpose()和permute()這兩個函數一定要contiguous().

5.總結
最重要的區別應該是上面的第三點和第四個。

另外,簡單的數據用transpose()就可以了,但是個人覺得不夠直觀,指向性弱了點;復雜維度的可以用permute(),對於維度的改變,一般更加精准。
————————————————
版權聲明:本文為CSDN博主「模糊包」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/xinjieyuan/article/details/105232802


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM