torch.repeat 和 torch.repeat_interleave


** 結論
torch.repeat: 輸入張量的從后往前的后面維度對應按照repeat中大小進行repeat操作(所以 輸入張量維度>= repeat維度)。 假設輸入張量為(a,b,c),repeat(x,y),則為b維度repeat x倍,c維度repeat y倍;最終輸出維度為(a, bx, cy)
torch.repeat_interleave: 輸入張量按照指定維度進行擴展,假設輸入張量為(2,2),torch.repeat_interleave(y, 3, dim=1), 原輸入張量大小為(2,2),則在維度1擴展3倍,最終為(2,6)。如果沒有指定dim,則會將輸入拉張開為1維向量再進行擴展

1.torch.repeat

x = torch.tensor([1, 2, 3])
x.repeat(4, 2), x.repeat(4, 2).shape, x.repeat(4, 2, 1).shape, x.repeat(2)
輸出
(tensor([[1, 2, 3, 1, 2, 3],
     [1, 2, 3, 1, 2, 3],
     [1, 2, 3, 1, 2, 3],
     [1, 2, 3, 1, 2, 3]]),

torch.Size([4, 6]),
torch.Size([4, 2, 3]),
tensor([1, 2, 3, 1, 2, 3]))

2.torch.repeat_interleave

x = torch.tensor([1, 2, 3])
x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])

y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])

torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])

torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
[3, 4],
[3, 4]])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM