pytorch中的unsqueeze函數和squeeze函數


在pytorch中,我們經常對張量Tensor的維度進行壓縮或者擴充(壓縮或者擴充的維度為1)。其中經常使用的是squeeze()函數和unsqueeze函數;
squeeze在英文中的意思就是“擠、壓”,所以故名思議,squeeze()函數就是對張量的維度進行減少的操作,話不多說,我們直接看下例子:

import torch
#定義兩個整型的張量a,b
a = torch.IntTensor([[1,2,3],[4,5,6]])
b = torch.IntTensor([[[1,2,3],[4,5,6]]])
#看一下a,b的形狀
print(a.shape)
print(b.shape)
'''
===output===
torch.Size([2, 3])
torch.Size([1, 2, 3])
'''

#我們看到張量b比較膨脹,有三個維度:1*2*3,所以我們要擠壓一下張量b的第0個維度(因為是1才能擠壓,否則沒有效果)
c = torch.squeeze(b,0)  # 對應的維度為第0維
print(c.shape)
'''
===output===
torch.Size([2, 3])
'''
#那如果想想張量a膨脹一下,怎么辦
c = torch.unsqueeze(a,0)
print(c.shape)
'''
===output===
torch.Size([1, 2, 3])
'''
#可以看到張量a在第0維也膨脹了, 如果你看不慣的話,再壓縮一下它。

另外,squeeze()函數和unsqueeze()函數還有另一種寫法,直接用張量類型的變量來調用這兩個函數:

c = a.unsqueeze(0)
print(c.shape)
'''
===output===
torch.Size([1, 2, 3])
'''

你看出差別了么?這里直接用張量變量a來調用了unsqueeze()函數,當然squeeze()也是一樣的,不信你可以試試_


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM