torch.repeat 和 torch.repeat_interleave


** 结论
torch.repeat: 输入张量的从后往前的后面维度对应按照repeat中大小进行repeat操作(所以 输入张量维度>= repeat维度)。 假设输入张量为(a,b,c),repeat(x,y),则为b维度repeat x倍,c维度repeat y倍;最终输出维度为(a, bx, cy)
torch.repeat_interleave: 输入张量按照指定维度进行扩展,假设输入张量为(2,2),torch.repeat_interleave(y, 3, dim=1), 原输入张量大小为(2,2),则在维度1扩展3倍,最终为(2,6)。如果没有指定dim,则会将输入拉张开为1维向量再进行扩展

1.torch.repeat

x = torch.tensor([1, 2, 3])
x.repeat(4, 2), x.repeat(4, 2).shape, x.repeat(4, 2, 1).shape, x.repeat(2)
输出
(tensor([[1, 2, 3, 1, 2, 3],
     [1, 2, 3, 1, 2, 3],
     [1, 2, 3, 1, 2, 3],
     [1, 2, 3, 1, 2, 3]]),

torch.Size([4, 6]),
torch.Size([4, 2, 3]),
tensor([1, 2, 3, 1, 2, 3]))

2.torch.repeat_interleave

x = torch.tensor([1, 2, 3])
x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])

y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])

torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])

torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
[3, 4],
[3, 4]])


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM