Pytorch-tensor維度的擴展,擠壓,擴張


數據本身不發生改變,數據的訪問方式發生了改變

1.維度的擴展

函數:unsqueeze()

# a是一個4維的
    a = torch.randn(4, 3, 28, 28)
    print('a.shape\n', a.shape)

    print('\n維度擴展(變成5維的):')
    print('第0維前加1維')
    print(a.unsqueeze(0).shape)
    print('第4維前加1維')
    print(a.unsqueeze(4).shape)
    print('在-1維前加1維')
    print(a.unsqueeze(-1).shape)
    print('在-4維前加1維')
    print(a.unsqueeze(-4).shape)
    print('在-5維前加1維')
    print(a.unsqueeze(-5).shape)

輸出結果

a.shape
 torch.Size([4, 3, 28, 28])

維度擴展(變成5維的):
第0維前加1維
torch.Size([1, 4, 3, 28, 28])
第4維前加1維
torch.Size([4, 3, 28, 28, 1])
在-1維前加1維
torch.Size([4, 3, 28, 28, 1])
在-4維前加1維
torch.Size([4, 1, 3, 28, 28])
在-5維前加1維
torch.Size([1, 4, 3, 28, 28])

注意,第5維前加1維,就會出錯

    # print(a.unsqueeze(5).shape)
    # Errot:Dimension out of range (expected to be in range of -5, 4], but got 5)

連續擴維:unsqueeze()

    # b是一個1維的
    b = torch.tensor([1.2, 2.3])
    print('b.shape\n', b.shape)
    print()
    # 0維之前插入1維,變成1,2]
    print(b.unsqueeze(0))
    print()
    # 1維之前插入1維,變成2,1]
    print(b.unsqueeze(1))

    # 連續擴維,然后再對某個維度進行擴張
    print(b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)

輸出結果

b.shape
 torch.Size([2])

tensor([[1.2000, 2.3000]])

tensor([[1.2000],
        [2.3000]])
torch.Size([1, 2, 1, 1])

2.擠壓維度

函數:squeeze()

    # 擠壓維度,只會擠壓shape為1的維度,如果shape不是1的話,當前值就不會變
    c = torch.randn(1, 32, 1, 2)
    print(c.shape)
    print(c.squeeze(0).shape)
    print(c.squeeze(1).shape)  # shape不是1,不會變
    print(c.squeeze(2).shape)
    print(c.squeeze(3).shape)  # shape不是1,不會變

輸出結果

torch.Size([1, 32, 1, 2])
torch.Size([32, 1, 2])
torch.Size([1, 32, 1, 2])
torch.Size([1, 32, 2])
torch.Size([1, 32, 1, 2])

3.維度擴張

函數1:expand():擴張到多少,

    # shape的擴張
    # expand():對shape為1的進行擴展,對shape不為1的只能保持不變,因為不知道如何變換,會報錯

    d = torch.randn(1, 32, 1, 1)
    print(d.shape)
    print(d.expand(4, 32, 14, 14).shape)

輸出結果

torch.Size([1, 32, 1, 1])
torch.Size([4, 32, 14, 14])

函數2:repeat()方法,擴張多少倍

    d=torch.randn([1,32,4,5])
    print(d.shape)
    print(d.repeat(4,32,2,3).shape)

輸出結果

torch.Size([1, 32, 4, 5])
torch.Size([4, 1024, 8, 15])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM