1. 擴展Tensor維度
相信剛接觸Pytorch的寶寶們,會遇到這樣一個問題,輸入的數據維度和實驗需要維度不一致,輸入的可能是2維數據或3維數據,實驗需要用到3維或4維數據,那么我們需要擴展這個維度。其實特別簡單,只要對數據加一個擴展維度方法就可以了。
1.1 torch.unsqueeze(self: Tensor, dim: _int)
torch.unsqueeze(self: Tensor, dim: _int)
參數說明:self:輸入的tensor數據,dim:要對哪個維度擴展就輸入那個維度的整數,可以輸入0,1,2……
1.2 Code
第一種方式,輸入數據后直接加unsqueeze()
擴展第一維和第二維為1
1 import torch 2 3 4 def reset_unsqueeze1(): 5 data = torch.rand([3, 3]) 6 data1 = data.unsqueeze(dim=0).unsqueeze(dim=1) 7 print("data_size: ", data.shape) 8 print("data: ", data) 9 print("data1_size: ", data1.shape) 10 print("data1: ", data1)
結果顯示
第二種方式,用torch.unsqueeze()
1 import torch 2 3 4 def reset_unsqueeze2(): 5 data = torch.rand([3, 3]) 6 data1 = torch.unsqueeze(data, dim=0) 7 print("data_size: ", data.shape) 8 print("data: ", data) 9 print("data1_size: ", data1.shape) 10 print("data1: ", data1)
結果顯示
2. 壓縮Tensor維度
2.1 torch.squeeze(self: Tensor, dim: _int)
這個方法剛好和torch.unsqueeze()方法效果相反,壓縮Tensor維度。
2.2 Code
第一種方式,輸入數據后直接加squeeze()
1 import torch 2 3 4 def reset_squeeze1(): 5 data = torch.rand([1, 1, 3, 3]) 6 data1 = data.squeeze(dim=0).squeeze(dim=1) 7 print("data_size: ", data.shape) 8 print("data: ", data) 9 print("data1_size: ", data1.shape) 10 print("data1: ", data1)
結果顯示
第二種方式,用torch.squeeze()
1 import torch 2 3 4 def reset_squeeze2(): 5 data = torch.rand([1, 1, 3, 3]) 6 data1 = torch.squeeze(data, dim=0) 7 print("data_size: ", data.shape) 8 print("data: ", data) 9 print("data1_size: ", data1.shape) 10 print("data1: ", data1)
結果顯示