在pytorch中,我們經常對張量Tensor的維度進行壓縮或者擴充(壓縮或者擴充的維度為1)。其中經常使用的是squeeze()
函數和unsqueeze
函數;
squeeze在英文中的意思就是“擠、壓”,所以故名思議,squeeze()
函數就是對張量的維度進行減少的操作,話不多說,我們直接看下例子:
import torch
#定義兩個整型的張量a,b
a = torch.IntTensor([[1,2,3],[4,5,6]])
b = torch.IntTensor([[[1,2,3],[4,5,6]]])
#看一下a,b的形狀
print(a.shape)
print(b.shape)
'''
===output===
torch.Size([2, 3])
torch.Size([1, 2, 3])
'''
#我們看到張量b比較膨脹,有三個維度:1*2*3,所以我們要擠壓一下張量b的第0個維度(因為是1才能擠壓,否則沒有效果)
c = torch.squeeze(b,0) # 對應的維度為第0維
print(c.shape)
'''
===output===
torch.Size([2, 3])
'''
#那如果想想張量a膨脹一下,怎么辦
c = torch.unsqueeze(a,0)
print(c.shape)
'''
===output===
torch.Size([1, 2, 3])
'''
#可以看到張量a在第0維也膨脹了, 如果你看不慣的話,再壓縮一下它。
另外,squeeze()
函數和unsqueeze()
函數還有另一種寫法,直接用張量類型的變量來調用這兩個函數:
c = a.unsqueeze(0)
print(c.shape)
'''
===output===
torch.Size([1, 2, 3])
'''
你看出差別了么?這里直接用張量變量a
來調用了unsqueeze()
函數,當然squeeze()
也是一樣的,不信你可以試試_