Pytorch-拼接與拆分


引言

本篇介紹tensor的拼接與拆分。

拼接與拆分

  • cat
  • stack
  • split
  • chunk

cat

  • numpy中使用concat,在pytorch中使用更加簡寫的 cat
  • 完成一個拼接
  • 兩個向量維度相同,想要拼接的維度上的值可以不同,但是其它維度上的值必須相同。

舉個例子:還是按照前面的,想將這兩組班級的成績合並起來

a[class 1-4, students, scores]

b[class 5-9, students, scores]

1
2
3
4
5
In[4]: a = torch.rand(4,32,8)
In[5]: b = torch.rand(5,32,8)
In[6]: torch.cat([a,b],dim=0).shape
Out[6]: torch.Size([9, 32, 8])
# 結果就是9個班級的成績

理解cat:

  • 行拼接:[4, 4] 與 [5, 4] 以 dim=0(行)進行拼接 —> [9, 4] 9個班的成績合起來
  • 列拼接:[4, 5] 與 [4, 3] 以 dim=1(列)進行拼接 —> [4, 8] 每個班合成8項成績

理解cat

例2:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
In[7]: a1 = torch.rand(4,3,32,32)
In[8]: a2 = torch.rand(5,3,32,32)
In[9]: torch.cat([a1,a2],dim=0).shape # 合並第1維 理解上相當於合並batch
Out[9]: torch.Size([9, 3, 32, 32])
In[11]: a2 = torch.rand(4,1,32,32)
In[12]: torch.cat([a1,a2],dim=1).shape # 合並第2維 理解上相當於合並為 rgba
Out[12]: torch.Size([4, 4, 32, 32])
In[13]: a1 = torch.rand(4,3,16,32)
In[14]: a2 = torch.rand(4,3,16,32)
In[15]: torch.cat([a1,a2],dim=3).shape # 合並第3維 理解上相當於合並照片的上下兩半
Out[15]: torch.Size([4, 3, 16, 64])
In[17]: a1 = torch.rand(4,3,32,32)
In[18]: torch.cat([a1,a2],dim=0).shape
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0.

stack

  • 創造一個新的維度(代表了新的組別)
  • 要求兩個tensor的size完全相同
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
In[19]: a1 = torch.rand(4,3,16,32)
In[20]: a2 = torch.rand(4,3,16,32)
In[21]: torch.cat([a1,a2],dim=2).shape # 合並照片的上下部分
Out[21]: torch.Size([4, 3, 32, 32])
In[22]: torch.stack([a1,a2],dim=2).shape # 添加了一個維度 一個值代表上半部分,一個值代表下半部分。 這顯然是沒有cat合適的。
Out[22]: torch.Size([4, 3, 2, 16, 32])
In[23]: a = torch.rand(32,8)
In[24]: b = torch.rand(32,8)
In[25]: torch.stack([a,b],dim=0).shape # 將兩個班級的學生成績合並,添加一個新的維度,這個維度的每個值代表一個班級。顯然是比cat合適的。
Out[25]: torch.Size([2, 32, 8])

In[26]: a.shape
Out[26]: torch.Size([32, 8])
In[27]: b = torch.rand([30,8])
In[28]: torch.stack([a,b],dim=0)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0

split

  • 按長度進行拆分:單元長度/數量
  • 長度相同給一個固定值
  • 長度不同給一個列表
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
In[48]: a = torch.rand(32,8)
In[49]: b = torch.rand(32,8)
In[50]: c = torch.rand(32,8)
In[51]: d = torch.rand(32,8)
In[52]: e = torch.rand(32,8)
In[53]: f = torch.rand(32,8)
In[54]: s = torch.stack([a,b,c,d,e,f],dim=0)
In[55]: s.shape
Out[55]: torch.Size([6, 32, 8])
In[57]: aa,bb = s.split(3,dim=0) # 按數量切分,可以使用一個常數
In[58]: aa.shape, bb.shape
Out[58]: (torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))
In[59]: cc,dd,ee = s.split([3,2,1],dim=0) # 按單位長度切分,可以使用一個列表
In[60]: cc.shape, dd.shape, ee.shape
Out[60]: (torch.Size([3, 32, 8]), torch.Size([2, 32, 8]), torch.Size([1, 32, 8]))

In[61]: ff,gg = s.split(6,dim=0) # 只切了一半,有一半不存在,所以報錯
ValueError: not enough values to unpack (expected 2, got 1)

chunk

  • 按數量進行拆分
1
2
3
4
5
6
7
8
In[63]: s.shape
Out[63]: torch.Size([6, 32, 8])
In[64]: aa,bb = s.chunk(2,dim=0)
In[65]: aa.shape, bb.shape
Out[65]: (torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))
In[66]: cc,dd = s.split(3,dim=0)
In[67]: cc.shape,dd.shape
Out[67]: (torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))

note:對於按數量切分:chunk中的參數是要切成幾份;split的常數是每份有幾個。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM