在pytorch
中轉置用的函數就只有這兩個:transpose()和
permute(),本文將詳細地介紹這兩個函數以及它們之間的區別。
transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor
函數返回輸入矩陣input
的轉置。交換維度dim0
和dim1
參數:
- input (Tensor) – 輸入張量,必填
- dim0 (int) – 轉置的第一維,默認0,可選
- dim1 (int) – 轉置的第二維,默認1,可選
注意只能有兩個相關的交換的位置參數。
例子:
>>> x = torch.randn(2, 3) >>> x tensor([[ 1.0028, -0.9893, 0.5809], [-0.1669, 0.7299, 0.4942]]) >>> torch.transpose(x, 0, 1) tensor([[ 1.0028, -0.1669], [-0.9893, 0.7299], [ 0.5809, 0.4942]])
permute()
參數: dims (int…*)-換位順序,必填
例子:
>>> x = torch.randn(2, 3, 5) >>> x.size() torch.Size([2, 3, 5]) >>> x.permute(2, 0, 1).size() torch.Size([5, 2, 3])
transpose與permute的異同
- permute相當於可以同時操作於tensor的若干維度,transpose只能同時作用於tensor的兩個維度;
- torch.transpose(x)合法, x.transpose()合法。torch.permute(x)不合法,x.permute()合法。
- 與contiguous、view函數之關聯。contiguous:view只能作用在contiguous的variable上,如果在view之前調用了transpose、permute等,就需要調用contiguous()來返回一個contiguous copy;一種可能的解釋是:有些tensor並不是占用一整塊內存,而是由不同的數據塊組成,而tensor的view()操作依賴於內存是整塊的,這時只需要執行contiguous()這個函數,把tensor變成在內存中連續分布的形式;判斷ternsor是否為contiguous,可以調用torch.Tensor.is_contiguous()函數:
import torch x = torch.ones(10, 10) x.is_contiguous() # True x.transpose(0, 1).is_contiguous() # False x.transpose(0, 1).contiguous().is_contiguous() # True
另:在pytorch的最新版本0.4版本中,增加了torch.reshape(),與 numpy.reshape() 的功能類似,大致相當於 tensor.contiguous().view(),這樣就省去了對tensor做view()變換前,調用contiguous()的麻煩;