Pytorch 擴展Tensor維度、壓縮Tensor維度


1. 擴展Tensor維度

    相信剛接觸Pytorch的寶寶們,會遇到這樣一個問題,輸入的數據維度和實驗需要維度不一致,輸入的可能是2維數據或3維數據,實驗需要用到3維或4維數據,那么我們需要擴展這個維度。其實特別簡單,只要對數據加一個擴展維度方法就可以了。

1.1 torch.unsqueeze(self: Tensor, dim: _int)

  torch.unsqueeze(self: Tensor, dim: _int)

  參數說明:self:輸入的tensor數據,dim:要對哪個維度擴展就輸入那個維度的整數,可以輸入0,1,2……

1.2 Code

第一種方式,輸入數據后直接加unsqueeze()

  擴展第一維和第二維為1

 1 import torch
 2 
 3 
 4 def reset_unsqueeze1():
 5     data = torch.rand([3, 3])
 6     data1 = data.unsqueeze(dim=0).unsqueeze(dim=1)
 7     print("data_size: ", data.shape)
 8     print("data: ", data)
 9     print("data1_size: ", data1.shape)
10     print("data1: ", data1)

結果顯示

 第二種方式,用torch.unsqueeze()

 1 import torch
 2 
 3 
 4 def reset_unsqueeze2():
 5     data = torch.rand([3, 3])
 6     data1 = torch.unsqueeze(data, dim=0)
 7     print("data_size: ", data.shape)
 8     print("data: ", data)
 9     print("data1_size: ", data1.shape)
10     print("data1: ", data1)

 結果顯示

 

2. 壓縮Tensor維度

2.1 torch.squeeze(self: Tensor, dim: _int)

  這個方法剛好和torch.unsqueeze()方法效果相反,壓縮Tensor維度。

2.2 Code

第一種方式,輸入數據后直接加squeeze()

 1 import torch
 2 
 3 
 4 def reset_squeeze1():
 5     data = torch.rand([1, 1, 3, 3])
 6     data1 = data.squeeze(dim=0).squeeze(dim=1)
 7     print("data_size: ", data.shape)
 8     print("data: ", data)
 9     print("data1_size: ", data1.shape)
10     print("data1: ", data1)

結果顯示

 第二種方式,用torch.squeeze()

 1 import torch
 2 
 3 
 4 def reset_squeeze2():
 5     data = torch.rand([1, 1, 3, 3])
 6     data1 = torch.squeeze(data, dim=0)
 7     print("data_size: ", data.shape)
 8     print("data: ", data)
 9     print("data1_size: ", data1.shape)
10     print("data1: ", data1)

結果顯示

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM