Pytorch之permute函數


1、主要作用:變換tensor維度

example:

import torch
x = torch.randn(2, 3, 5)
print(x.size())
print(x.permute(2, 0, 1).size())

>>>torch.Size([2, 3, 5])
>>>torch.Size([5, 2, 3])

2、介紹一下transpose與permute的異同:

同:都是對tensor維度進行轉置;

異:permute函數可以對任意高維矩陣進行轉置,但沒有torch.permute()這個調用方式

torch.randn(2,3,4,5).permute(3,2,0,1).shape

>>>torch.Size([5, 4, 2, 3])

transpose只能操作2D矩陣的轉置,無法操作超過2個維度,所以要想實現多個維度的轉置,既可以用一次性的

permute,也可以多次使用transpose;

torch.randn(2,3,4,5).transpose(3,0).transpose(2,1).transpose(3,2).shape

>>>torch.Size([5, 4, 2, 3])

3、permute函數與contiguous、view函數的關聯

contiguous: view只能作用在contiguous的variable上,如果在view之前調用了transpose、permute等,就需要調用

contiguous()來返回一個contiguous的copy;

也就是說transpose、permute等操作會讓tensor變得在內存上不連續,因此要想view,就得讓tensor先連續;

解釋如下:有些tensor並不是占用一整塊內存,而是由不同的數據塊組成,而tensor的view()操作依賴於內存是整塊的,這時只需要執行contiguous()這個函數,把tensor變成在內存中連續分布的形式;

判斷ternsor是否為contiguous,可以調用torch.Tensor.is_contiguous()函數:

import torch 
x = torch.ones(10, 10) 
x.is_contiguous()                                 # True 
x.transpose(0, 1).is_contiguous()                 # False
x.transpose(0, 1).contiguous().is_contiguous()    # True

另:在pytorch的最新版本0.4版本中,增加了torch.reshape(),與 numpy.reshape() 的功能類似,大致相當於 tensor.contiguous().view(),這樣就省去了對tensor做view()變換前,調用contiguous()的麻煩;

3、permute與view函數功能

import torch
import numpy as np

a=np.array([[[1,2,3],[4,5,6]]])
unpermuted=torch.tensor(a)
print(unpermuted.size())              #  ——>  torch.Size([1, 2, 3])

permuted=unpermuted.permute(2,0,1)
print(permuted.size())                #  ——>  torch.Size([3, 1, 2])

view_test = unpermuted.view(1,3,2)
print(view_test.size())   

>>>torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
torch.Size([1, 3, 2])

 參考:https://zhuanlan.zhihu.com/p/76583143


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM