引言
拼接與拆分
- cat
- stack
- split
- chunk
cat
- numpy中使用concat,在pytorch中使用更加簡寫的 cat
- 完成一個拼接
- 兩個向量維度相同,想要拼接的維度上的值可以不同,但是其它維度上的值必須相同。
舉個例子:還是按照前面的,想將這兩組班級的成績合並起來
a[class 1-4, students, scores]
b[class 5-9, students, scores]
1 |
In[4]: a = torch.rand(4,32,8) |
理解cat:
- 行拼接:[4, 4] 與 [5, 4] 以 dim=0(行)進行拼接 —> [9, 4] 9個班的成績合起來
- 列拼接:[4, 5] 與 [4, 3] 以 dim=1(列)進行拼接 —> [4, 8] 每個班合成8項成績
例2:
1 |
In[7]: a1 = torch.rand(4,3,32,32) |
stack
- 創造一個新的維度(代表了新的組別)
- 要求兩個tensor的size完全相同
1 |
In[19]: a1 = torch.rand(4,3,16,32) |
split
- 按長度進行拆分:單元長度/數量
- 長度相同給一個固定值
- 長度不同給一個列表
1 |
In[48]: a = torch.rand(32,8) |
chunk
- 按數量進行拆分
1 |
In[63]: s.shape |
note:對於按數量切分:chunk中的參數是要切成幾份;split的常數是每份有幾個。