说明:
移除指定维后,返回一个元组,包含了沿着指定维切片后的各个切片。
参数:
- tensor(Tensor) -- 输入张量
- dim(int) -- 删除的维度(按照某一个维度展开,返回切片)
注意:
不改变原来的tensor的shape,只是返回展开后的切片
import torch
t = torch.rand(3,3) #随机生成一个tensor
print(t)
print(t.shape)
r = torch.unbind(t,dim=0)#dim = 0指定拆除的维度
print(r)
s = torch.unbind(t,dim=1)#dim = 1指定拆除的维度
print(s)
jupyter notebook输出结果: