pytorch 中改變tensor維度(transpose)、拼接(cat)、壓縮(squeeze)詳解


具體示例如下,注意觀察維度的變化

1.改變tensor維度的操作:transpose、view、permute、t()、expand、repeat

#coding=utf-8
import  torch

def change_tensor_shape():
    x=torch.randn(2,4,3)
    s=x.transpose(1,2) #shape=[2,3,4]
    y=x.view(2,3,4) #shape=[2,3,4]
    z=x.permute(0,2,1) #shape=[2,3,4]

    #tensor.t()只能轉化 a 2D tensor
    m=torch.randn(2,3)#shape=[2,3]
    n=m.t()#shape=[3,2]
    print(m)
    print(n)

    #返回當前張量在某個維度為1擴展為更大的張量
    x = torch.Tensor([[1], [2], [3]])#shape=[3,1]
    t=x.expand(3, 4)
    print(t)
    '''
    tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.]])
    '''

    #沿着特定的維度重復這個張量
    x=torch.Tensor([[1,2,3]])
    t=x.repeat(3, 2)
    print(t)
    '''
    tensor([[1., 2., 3., 1., 2., 3.],
        [1., 2., 3., 1., 2., 3.],
        [1., 2., 3., 1., 2., 3.]])
    '''
    x = torch.randn(2, 3, 4)
    t=x.repeat(2, 1, 3) #shape=[4, 3, 12]

if __name__=='__main__':
    change_tensor_shape()

2.tensor的拼接:cat、stack

除了要拼接的維度可以不相等,其他維度必須相等

#coding=utf-8
import  torch


def cat_and_stack():

    x = torch.randn(2,3,6)
    y = torch.randn(2,4,6)
    c=torch.cat((x,y),1)
    #c=(2*7*6)
    print(c.size)

    """
    而stack則會增加新的維度。
    如對兩個1*2維的tensor在第0個維度上stack,則會變為2*1*2的tensor;在第1個維度上stack,則會變為1*2*2的tensor。
    """
    a = torch.rand((1, 2))
    b = torch.rand((1, 2))
    c = torch.stack((a, b), 0)
    print(c.size())

if __name__=='__main__':
    cat_and_stack()

 3.壓縮和擴展維度:改變tensor中只有1個維度的tensor

 torch.squeeze(input, dim=None, out=None) → Tensor

除去輸入張量input中數值為1的維度,並返回新的張量。如果輸入張量的形狀為(A×1×B×C×1×D) 那么輸出張量的形狀為(A×B×C×D)

當通過dim參數指定維度時,維度壓縮操作只會在指定的維度上進行。如果輸入向量的形狀為(A×1×B),
squeeze(input, 0)會保持張量的維度不變,只有在執行squeeze(input, 1)時,輸入張量的形狀會被壓縮至(A×B) 。

如果一個張量只有1個維度,那么它不會受到上述方法的影響。

#coding=utf-8
import  torch


def squeeze_tensor():
    x = torch.Tensor(1,3)
    y=torch.squeeze(x, 0)
    print("y:",y)
    y=torch.unsqueeze(y, 1)
    print("y:",y)

if __name__=='__main__':
    squeeze_tensor()


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM