torch 中的torch.squeeze()和torch.unsqueeze()


1. torch.squeeze(input, dim=None, out=None)

input是输入的参数,dim是指定要合并维度为1的所在维度

当dim=0时原样输出,当dim=1时合并维度为1的行,dim=2 合并维度为1的列,当所在的行和列的维度不为1时原样输出,

例如:

import torch as t

a=t.araneg(8).view(4,1,2)#生成四个一行两列的tensor

t.squeeze(a,dim=0)#原样输出tensor

结果为:

tensor([[[0, 1]],

        [[2, 3]],

        [[4, 5]],

        [[6, 7]]])
当dim=1时,由于行所在的维度为1,因此合并行,生成4行两列的tensor
t.squeeze(a,dim=1)
结果为:
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
当dim=2时,由于列的维度为2,所以原样输出
t.squeeze(a,dim=2)
结果为:
tensor([[[0, 1]],

        [[2, 3]],

        [[4, 5]],

        [[6, 7]]])
但是我们将原来的tensor换成2个4行1列的tensor,当dim=2时,将会生成2行4列的tensor
import torch as t
a=t.arange(8).view(2,4,1)
t.squeeze(a,dim=2)
输出的结果为:
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

2. torch.unsqueeze(input, dim, out=None)

插入一个维度唯一的维度

dim=0原样输出,dim=1在山上插入维度为1 的维度,dim=2在列上插入维度为1 的维度

比如某一tensor为(2,4)

当dim=1时变成(2,1,4)两个1行4列的tensor

当dim=2时变成(2,4,1)变成两个4行1列的tensor

import torch as t
a=t.arange(8).view(2,4)

a

结果为:

tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

1)t.unsqueeze(a,dim=0)#原样输出

结果为:

tensor([[[0, 1, 2, 3],
         [4, 5, 6, 7]]])

2)t.unsqueeze(a,dim=1)#在行的维度加1,变成2个1行4列,即(2,1,4)
tensor([[[0, 1, 2, 3]],

        [[4, 5, 6, 7]]])
3)t.unqueeze(a,dim=2)#在列的维度加1 变成2个4行1列,即(2,4,1)
tensor([[[0],
         [1],
         [2],
         [3]],

        [[4],
         [5],
         [6],
         [7]]])


 




免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM