1. cat concatenation
- torch.cat joins tensors directly along an existing dimension, given by dim.
- The default is dim=0.
- The inputs may differ in size only on the chosen dim; sizes on every other dim must match, otherwise an error is raised.
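The shape rule can be sketched directly (the tensors `x` and `y` here are illustrative, not the `a`/`b` used below):

```python
import torch

# cat joins along an existing dim; the inputs may differ only on that dim,
# and the output's size there is the sum of the inputs' sizes.
x = torch.rand(2, 3, 2)
y = torch.rand(5, 3, 2)          # differs from x only on dim 0
out = torch.cat([x, y], dim=0)   # 2 + 5 = 7 on dim 0
print(out.shape)                 # torch.Size([7, 3, 2])
```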
1) Concatenating two tensors of the same shape
a = torch.rand(2, 3, 2)
a
# Output:
tensor([[[0.6072, 0.6531],
[0.2023, 0.2506],
[0.0590, 0.3390]],
[[0.3994, 0.0110],
[0.3615, 0.3826],
[0.3033, 0.3096]]])
b = torch.rand(2, 3, 2) # define b with the same shape as a
b
# Output:
tensor([[[0.6144, 0.4561],
[0.9263, 0.0644],
[0.2838, 0.3456]],
[[0.1126, 0.5303],
[0.8140, 0.5715],
[0.7627, 0.5095]]])
# dim selects the dimension to join along
torch.cat([a, b]) # when dim is not given, it defaults to 0
# Output:
tensor([[[0.6072, 0.6531],
[0.2023, 0.2506],
[0.0590, 0.3390]],
[[0.3994, 0.0110],
[0.3615, 0.3826],
[0.3033, 0.3096]],
[[0.6144, 0.4561],
[0.9263, 0.0644],
[0.2838, 0.3456]],
[[0.1126, 0.5303],
[0.8140, 0.5715],
[0.7627, 0.5095]]])
# join along dim=0
torch.cat([a, b], dim=0) # explicit dim=0; the result matches the one above
# Output:
tensor([[[0.6072, 0.6531],
[0.2023, 0.2506],
[0.0590, 0.3390]],
[[0.3994, 0.0110],
[0.3615, 0.3826],
[0.3033, 0.3096]],
[[0.6144, 0.4561],
[0.9263, 0.0644],
[0.2838, 0.3456]],
[[0.1126, 0.5303],
[0.8140, 0.5715],
[0.7627, 0.5095]]])
# join along dim=1
torch.cat([a, b], dim=1)
# Output:
tensor([[[0.6072, 0.6531],
[0.2023, 0.2506],
[0.0590, 0.3390],
[0.6144, 0.4561],
[0.9263, 0.0644],
[0.2838, 0.3456]],
[[0.3994, 0.0110],
[0.3615, 0.3826],
[0.3033, 0.3096],
[0.1126, 0.5303],
[0.8140, 0.5715],
[0.7627, 0.5095]]])
# join along dim=2
torch.cat([a, b], dim=2)
# Output:
tensor([[[0.6072, 0.6531, 0.6144, 0.4561],
[0.2023, 0.2506, 0.9263, 0.0644],
[0.0590, 0.3390, 0.2838, 0.3456]],
[[0.3994, 0.0110, 0.1126, 0.5303],
[0.3615, 0.3826, 0.8140, 0.5715],
[0.3033, 0.3096, 0.7627, 0.5095]]])
2) Concatenating two tensors of different shapes
Comparing against the same-shape case above makes this easier to follow.
a = torch.rand(2, 3, 2)
a
# Output:
tensor([[[0.6447, 0.9758],
[0.0688, 0.9082],
[0.0083, 0.0109]],
[[0.5239, 0.1217],
[0.9562, 0.6831],
[0.8691, 0.2769]]])
b = torch.rand(2, 2, 2)
b
# Output:
tensor([[[0.3604, 0.7585],
[0.7831, 0.0439]],
[[0.2040, 0.5002],
[0.8878, 0.5973]]])
# without specifying dim:
torch.cat([a, b])
# dim defaults to 0, and a and b differ in size on dim 1 (3 for a vs. 2 for b), so this raises an error
# Output:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-32-5484713fecdf> in <module>
----> 1 torch.cat([a, b])
      2 # Output:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 2 in dimension 1
# a and b differ only on dim 1, so cat can join them only along dim=1
torch.cat([a, b], dim=1)
# Output:
tensor([[[0.6447, 0.9758],
[0.0688, 0.9082],
[0.0083, 0.0109],
[0.3604, 0.7585],
[0.7831, 0.0439]],
[[0.5239, 0.1217],
[0.9562, 0.6831],
[0.8691, 0.2769],
[0.2040, 0.5002],
[0.8878, 0.5973]]])
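The failure mode above can be caught programmatically. A minimal sketch with the same shapes as the example:

```python
import torch

a = torch.rand(2, 3, 2)
b = torch.rand(2, 2, 2)
try:
    torch.cat([a, b], dim=0)   # dim 1 disagrees (3 vs. 2) -> RuntimeError
    raised = False
except RuntimeError:
    raised = True
print(raised)                  # True

# joining along the mismatched dim itself works fine
c = torch.cat([a, b], dim=1)   # 3 + 2 = 5 on dim 1
print(c.shape)                 # torch.Size([2, 5, 2])
```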
2. stack concatenation
- Unlike cat, stack first inserts a new dimension at the given dim and then joins along it.
- You can think of it as wrapping each input in an extra pair of brackets at the chosen dim before joining.
- Comparing with cat: along the same dim, cat keeps the two inputs' data inside one set of brackets, while stack keeps them in two separate sets.
- The tensors passed to stack must have exactly the same shape.
- The default is dim=0.
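The bullet points above can be verified on shapes alone (the names `x`, `y` are illustrative):

```python
import torch

# stack inserts a brand-new dimension of size len(tensors) at position dim;
# every input must have exactly the same shape.
x = torch.rand(2, 5)
y = torch.rand(2, 5)
s0 = torch.stack([x, y], dim=0)   # shape (2, 2, 5): new axis in front
s1 = torch.stack([x, y], dim=1)   # shape (2, 2, 5): new axis in the middle
s2 = torch.stack([x, y], dim=2)   # shape (2, 5, 2): new axis at the end
```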
a = torch.rand(2, 5)
a
# Output:
tensor([[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
[0.6929, 0.4945, 0.0631, 0.4546, 0.6918]])
b = torch.rand(2, 5)
b
# Output:
tensor([[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
[0.4722, 0.7540, 0.5282, 0.0625, 0.0448]])
# dim defaults to 0: the two tensors are stacked directly
torch.stack([a, b])
# Output:
tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
[0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],
[[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
[0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])
# explicit dim=0
# compare with cat: for the same dim=0, cat puts the data inside one set of brackets, whereas here the data ends up in two parts (two sets of brackets)
torch.stack([a, b], dim=0) # the result matches the default above
# Output:
tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
[0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],
[[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
[0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])
# dim=1: stack the data along dimension 1.
# Note: the result differs from the dim=0 case above.
torch.stack([a, b], dim=1)
# Output:
tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
[0.7893, 0.4141, 0.2971, 0.6791, 0.9791]],
[[0.6929, 0.4945, 0.0631, 0.4546, 0.6918],
[0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])
# dim=2: stack the data along dimension 2.
torch.stack([a, b], dim=2)
# Output:
tensor([[[0.2214, 0.7893],
[0.2666, 0.4141],
[0.6486, 0.2971],
[0.7050, 0.6791],
[0.4259, 0.9791]],
[[0.6929, 0.4722],
[0.4945, 0.7540],
[0.0631, 0.5282],
[0.4546, 0.0625],
[0.6918, 0.0448]]])
3. split
- Specify the dim to split along
- Provide the sizes of the pieces after splitting
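split accepts either a list of piece sizes, which must sum to the size of the chosen dim, or a single int giving the size of every piece (the last piece may be smaller). A quick sketch:

```python
import torch

a = torch.rand(4, 3, 2)
p1, p2 = a.split([1, 3], dim=0)   # list form: sizes must sum to 4
q1, q2 = a.split(2, dim=0)        # int form: every piece has size 2 on dim 0
print(p1.shape, p2.shape)         # torch.Size([1, 3, 2]) torch.Size([3, 3, 2])
```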
a = torch.rand(4, 3, 2)
a
# Output:
tensor([[[0.5790, 0.6024],
[0.4730, 0.0734],
[0.2274, 0.7212]],
[[0.7051, 0.1568],
[0.5890, 0.1075],
[0.7469, 0.0659]],
[[0.7780, 0.5424],
[0.4344, 0.8551],
[0.6729, 0.7372]],
[[0.1669, 0.8596],
[0.9490, 0.8378],
[0.7889, 0.2192]]])
# dim defaults to 0
# dim 0 has size 4, so it can be split as 2 + 2 = 4, or 1 + 3 = 4, and so on
a.split([2, 2])
# Output:
(tensor([[[0.5790, 0.6024],
[0.4730, 0.0734],
[0.2274, 0.7212]],
[[0.7051, 0.1568],
[0.5890, 0.1075],
[0.7469, 0.0659]]]),
tensor([[[0.7780, 0.5424],
[0.4344, 0.8551],
[0.6729, 0.7372]],
[[0.1669, 0.8596],
[0.9490, 0.8378],
[0.7889, 0.2192]]]))
# dim 1 has size 3, so split it as 2 + 1 = 3
a.split([2, 1], dim=1)
# Output:
(tensor([[[0.5790, 0.6024],
[0.4730, 0.0734]],
[[0.7051, 0.1568],
[0.5890, 0.1075]],
[[0.7780, 0.5424],
[0.4344, 0.8551]],
[[0.1669, 0.8596],
[0.9490, 0.8378]]]),
tensor([[[0.2274, 0.7212]],
[[0.7469, 0.0659]],
[[0.6729, 0.7372]],
[[0.7889, 0.2192]]]))
# dim 2 has size 2, so split it as 1 + 1 = 2
a.split([1, 1], dim=2)
# Output:
(tensor([[[0.5790],
[0.4730],
[0.2274]],
[[0.7051],
[0.5890],
[0.7469]],
[[0.7780],
[0.4344],
[0.6729]],
[[0.1669],
[0.9490],
[0.7889]]]),
tensor([[[0.6024],
[0.0734],
[0.7212]],
[[0.1568],
[0.1075],
[0.0659]],
[[0.5424],
[0.8551],
[0.7372]],
[[0.8596],
[0.8378],
[0.2192]]]))
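split is the inverse of cat along the same dim: concatenating the pieces restores the original tensor. A minimal round-trip sketch:

```python
import torch

a = torch.rand(4, 3, 2)
pieces = a.split([2, 2], dim=0)      # two pieces of size 2 along dim 0
restored = torch.cat(pieces, dim=0)  # glue them back together
print(torch.equal(restored, a))      # True
```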
4. chunk
- chunk takes the number of pieces to split into along the given dim
- each piece has size ceil(dim_size / chunks); when the dim is not evenly divisible, the last piece is smaller, and fewer pieces than requested may be returned
- dim defaults to 0
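The piece-count rule can be checked directly: chunk(n) makes each piece ceil(size / n) long, so fewer than n pieces may come back. A sketch (the tensor `b` is illustrative):

```python
import torch

a = torch.rand(4, 3, 2)
parts = a.chunk(3, dim=0)            # ceil(4/3) = 2 per piece -> only 2 pieces
print(len(parts))                    # 2

b = torch.rand(5, 2)
parts5 = b.chunk(3, dim=0)           # ceil(5/3) = 2 -> pieces of size 2, 2, 1
print([p.shape[0] for p in parts5])  # [2, 2, 1]
```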
a
# Output:
tensor([[[0.5790, 0.6024],
[0.4730, 0.0734],
[0.2274, 0.7212]],
[[0.7051, 0.1568],
[0.5890, 0.1075],
[0.7469, 0.0659]],
[[0.7780, 0.5424],
[0.4344, 0.8551],
[0.6729, 0.7372]],
[[0.1669, 0.8596],
[0.9490, 0.8378],
[0.7889, 0.2192]]])
# dim defaults to 0
# split the data into 4 equal pieces along dim=0
a.chunk(4)
# Output:
(tensor([[[0.5790, 0.6024],
[0.4730, 0.0734],
[0.2274, 0.7212]]]),
tensor([[[0.7051, 0.1568],
[0.5890, 0.1075],
[0.7469, 0.0659]]]),
tensor([[[0.7780, 0.5424],
[0.4344, 0.8551],
[0.6729, 0.7372]]]),
tensor([[[0.1669, 0.8596],
[0.9490, 0.8378],
[0.7889, 0.2192]]]))
# split the data into 3 pieces along dim=0
# each piece gets ceil(4 / 3) = 2 slices, so although 3 pieces were requested, only 2 come back
a.chunk(3, dim=0)
# Output:
(tensor([[[0.5790, 0.6024],
[0.4730, 0.0734],
[0.2274, 0.7212]],
[[0.7051, 0.1568],
[0.5890, 0.1075],
[0.7469, 0.0659]]]),
tensor([[[0.7780, 0.5424],
[0.4344, 0.8551],
[0.6729, 0.7372]],
[[0.1669, 0.8596],
[0.9490, 0.8378],
[0.7889, 0.2192]]]))
# split the data into 3 equal pieces along dim=1
a.chunk(3, dim=1)
# Output:
(tensor([[[0.5790, 0.6024]],
[[0.7051, 0.1568]],
[[0.7780, 0.5424]],
[[0.1669, 0.8596]]]),
tensor([[[0.4730, 0.0734]],
[[0.5890, 0.1075]],
[[0.4344, 0.8551]],
[[0.9490, 0.8378]]]),
tensor([[[0.2274, 0.7212]],
[[0.7469, 0.0659]],
[[0.6729, 0.7372]],
[[0.7889, 0.2192]]]))
# split the data into 2 equal pieces along dim=2
a.chunk(2, dim=2)
# Output:
(tensor([[[0.5790],
[0.4730],
[0.2274]],
[[0.7051],
[0.5890],
[0.7469]],
[[0.7780],
[0.4344],
[0.6729]],
[[0.1669],
[0.9490],
[0.7889]]]),
tensor([[[0.6024],
[0.0734],
[0.7212]],
[[0.1568],
[0.1075],
[0.0659]],
[[0.5424],
[0.8551],
[0.7372]],
[[0.8596],
[0.8378],
[0.2192]]]))