Pytorch tensor的復制函數torch.repeat_interleave()


1. repeat_interleave(self: Tensor, repeats: _int, dim: Optional[_int]=None)

參數說明:

self: 傳入的數據為tensor

repeats: 復制的份數

dim: 要復制的維度,可設定為0/1/2.....

2. 例子

2.1 Code

此處定義了一個4維tensor,要對第2個維度復制,由原來的1變為3,即將設定dim=1。

 1 import torch
 2 
 3 
 4 def function():
 5     data1 = torch.rand([2, 1, 3, 3])
 6     print("data1_shape: ", data1.shape)
 7     print("data1: ", data1)
 8 
 9     data2 = torch.repeat_interleave(data1, repeats=3, dim=1)
10     print("data2_shape: ", data2.shape)
11     print("data2: ", data2)
12 
13 
14 if __name__ == '__main__':
15     function()
View Code

2.2 輸出顯示

即可看到輸入tensor形狀為[2, 1, 3, 3],經過repeat后,tensor變為[2, 3, 3, 3],並在第二維度上保持相同的數據。

 

 

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM