PyTorch tensor concatenation and splitting: cat, stack, split, chunk


1. Concatenating with cat

  • Function: joins tensors along the dimension given by dim.
  • Defaults to dim=0.
  • The inputs may differ in size only along the chosen dim; every other dim must match, otherwise a RuntimeError is raised.
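These rules are easiest to check on shapes alone; a minimal sketch of the examples that follow:

```python
import torch

a = torch.rand(2, 3, 2)
b = torch.rand(2, 3, 2)

# The chosen dim's sizes are summed; every other dim keeps its shared size.
print(torch.cat([a, b], dim=0).shape)  # torch.Size([4, 3, 2])
print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 6, 2])
print(torch.cat([a, b], dim=2).shape)  # torch.Size([2, 3, 4])
```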

1) Concatenating two tensors of the same shape

a = torch.rand(2, 3, 2)
a
# Output:
    tensor([[[0.6072, 0.6531],
             [0.2023, 0.2506],
             [0.0590, 0.3390]],

            [[0.3994, 0.0110],
             [0.3615, 0.3826],
             [0.3033, 0.3096]]])
    
b = torch.rand(2, 3, 2)  # b has the same shape as a
b
# Output:
    tensor([[[0.6144, 0.4561],
             [0.9263, 0.0644],
             [0.2838, 0.3456]],

            [[0.1126, 0.5303],
             [0.8140, 0.5715],
             [0.7627, 0.5095]]])   

    
# dim selects the dimension to concatenate along
torch.cat([a, b])  # dim defaults to 0 when not specified
# Output:
    tensor([[[0.6072, 0.6531],
             [0.2023, 0.2506],
             [0.0590, 0.3390]],

            [[0.3994, 0.0110],
             [0.3615, 0.3826],
             [0.3033, 0.3096]],

            [[0.6144, 0.4561],
             [0.9263, 0.0644],
             [0.2838, 0.3456]],

            [[0.1126, 0.5303],
             [0.8140, 0.5715],
             [0.7627, 0.5095]]])

# concatenate along dim=0
torch.cat([a, b], dim=0)  # explicitly dim=0; the result matches the default above
# Output:
    tensor([[[0.6072, 0.6531],
             [0.2023, 0.2506],
             [0.0590, 0.3390]],

            [[0.3994, 0.0110],
             [0.3615, 0.3826],
             [0.3033, 0.3096]],

            [[0.6144, 0.4561],
             [0.9263, 0.0644],
             [0.2838, 0.3456]],

            [[0.1126, 0.5303],
             [0.8140, 0.5715],
             [0.7627, 0.5095]]])
    
# concatenate along dim=1
torch.cat([a, b], dim=1)
# Output:
    tensor([[[0.6072, 0.6531],
             [0.2023, 0.2506],
             [0.0590, 0.3390],
             [0.6144, 0.4561],
             [0.9263, 0.0644],
             [0.2838, 0.3456]],

            [[0.3994, 0.0110],
             [0.3615, 0.3826],
             [0.3033, 0.3096],
             [0.1126, 0.5303],
             [0.8140, 0.5715],
             [0.7627, 0.5095]]])    
    
# concatenate along dim=2
torch.cat([a, b], dim=2)
# Output:
    tensor([[[0.6072, 0.6531, 0.6144, 0.4561],
             [0.2023, 0.2506, 0.9263, 0.0644],
             [0.0590, 0.3390, 0.2838, 0.3456]],

            [[0.3994, 0.0110, 0.1126, 0.5303],
             [0.3615, 0.3826, 0.8140, 0.5715],
             [0.3033, 0.3096, 0.7627, 0.5095]]])    

2) Concatenating two tensors of different shapes

Compare these with the same-shape examples above; the contrast makes the rule easier to see.

a = torch.rand(2, 3, 2)
a
# Output:
    tensor([[[0.6447, 0.9758],
             [0.0688, 0.9082],
             [0.0083, 0.0109]],

            [[0.5239, 0.1217],
             [0.9562, 0.6831],
             [0.8691, 0.2769]]])
    
b = torch.rand(2, 2, 2)
b
# Output:
    tensor([[[0.3604, 0.7585],
             [0.7831, 0.0439]],

            [[0.2040, 0.5002],
             [0.8878, 0.5973]]])

# without specifying dim:
torch.cat([a, b])
# dim defaults to 0, and a and b differ in size on dim 1 (a is 3, b is 2), so this raises an error
# Output:
    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-32-5484713fecdf> in <module>
    ----> 1 torch.cat([a, b])
          2 # Output:

    RuntimeError: inv
        
# a and b differ only on dim 1, so cat can only join them along dim=1
torch.cat([a, b], dim=1)
# Output:
    tensor([[[0.6447, 0.9758],
             [0.0688, 0.9082],
             [0.0083, 0.0109],
             [0.3604, 0.7585],
             [0.7831, 0.0439]],

            [[0.5239, 0.1217],
             [0.9562, 0.6831],
             [0.8691, 0.2769],
             [0.2040, 0.5002],
             [0.8878, 0.5973]]])        
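The rule above can be verified with a quick sketch: the only valid cat dim is the one where the sizes differ, and any other choice raises a RuntimeError.

```python
import torch

a = torch.rand(2, 3, 2)
b = torch.rand(2, 2, 2)

# a and b differ only on dim 1, so dim=1 is the only valid cat dim.
print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 5, 2])

# Any other dim leaves dim 1 mismatched and raises RuntimeError.
try:
    torch.cat([a, b], dim=0)
except RuntimeError:
    print("cat on dim=0 failed as expected")
```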

2. Stacking with stack

  • Unlike cat, stack inserts a new dimension at the given dim and then joins the tensors.
  • Think of it as wrapping each tensor in an extra pair of [] at the chosen dim before joining.
  • Compared with cat: cat puts the two inputs' data inside one [], while stack puts them in two separate [].
  • The tensors passed to stack must have identical shapes on every dim.
  • Defaults to dim=0.
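One way to see what stack does is to compare shapes, and to note that stack is equivalent to unsqueezing each input at dim and then calling cat (a small sketch):

```python
import torch

a = torch.rand(2, 5)
b = torch.rand(2, 5)

# stack inserts a new dim of size len(tensors) at the given position.
print(torch.stack([a, b], dim=0).shape)  # torch.Size([2, 2, 5])
print(torch.stack([a, b], dim=1).shape)  # torch.Size([2, 2, 5])
print(torch.stack([a, b], dim=2).shape)  # torch.Size([2, 5, 2])

# Equivalent view: unsqueeze each input at dim, then cat.
same = torch.equal(torch.stack([a, b], dim=1),
                   torch.cat([a.unsqueeze(1), b.unsqueeze(1)], dim=1))
print(same)  # True
```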
a = torch.rand(2, 5)
a
# Output:
    tensor([[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
            [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]])

b = torch.rand(2, 5)
b
# Output:
    tensor([[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
            [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]])
    
# dim defaults to 0: the two tensors are stacked directly
torch.stack([a, b])
# Output:
    tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
             [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],

            [[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
             [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
    
# explicit dim=0
# compare with cat at dim=0: there the data share one []; here the data sit in two separate []
torch.stack([a, b], dim=0)  # same result as the default above
# Output:
    tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
             [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],

            [[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
             [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
    
# dim=1: stack along dimension 1.
# Note the result differs from dim=0 above.
torch.stack([a, b], dim=1)
# Output:
    tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
             [0.7893, 0.4141, 0.2971, 0.6791, 0.9791]],

            [[0.6929, 0.4945, 0.0631, 0.4546, 0.6918],
             [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
    
# dim=2: stack along dimension 2.
torch.stack([a, b], dim=2)
# Output:
    tensor([[[0.2214, 0.7893],
             [0.2666, 0.4141],
             [0.6486, 0.2971],
             [0.7050, 0.6791],
             [0.4259, 0.9791]],

            [[0.6929, 0.4722],
             [0.4945, 0.7540],
             [0.0631, 0.5282],
             [0.4546, 0.0625],
             [0.6918, 0.0448]]])    

3. Splitting with split

  • Choose the dim to split along.
  • Give the sizes of the pieces after splitting.
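Besides a list of piece sizes, split also accepts a single int chunk size; a small sketch of both forms:

```python
import torch

a = torch.rand(4, 3, 2)

# A list gives explicit piece sizes; they must sum to the size of dim.
parts = a.split([1, 3], dim=0)
print([tuple(p.shape) for p in parts])  # [(1, 3, 2), (3, 3, 2)]

# A single int gives pieces of that size; the last piece may be smaller.
parts = a.split(3, dim=0)
print([tuple(p.shape) for p in parts])  # [(3, 3, 2), (1, 3, 2)]
```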
a = torch.rand(4, 3, 2)
a
# Output:
    tensor([[[0.5790, 0.6024],
             [0.4730, 0.0734],
             [0.2274, 0.7212]],

            [[0.7051, 0.1568],
             [0.5890, 0.1075],
             [0.7469, 0.0659]],

            [[0.7780, 0.5424],
             [0.4344, 0.8551],
             [0.6729, 0.7372]],

            [[0.1669, 0.8596],
             [0.9490, 0.8378],
             [0.7889, 0.2192]]])

    
# dim defaults to 0
# dim 0 has size 4, so any sizes that sum to 4 work: [2, 2], [1, 3], ...
a.split([2, 2])
# Output:
    (tensor([[[0.5790, 0.6024],
              [0.4730, 0.0734],
              [0.2274, 0.7212]],

             [[0.7051, 0.1568],
              [0.5890, 0.1075],
              [0.7469, 0.0659]]]),
     tensor([[[0.7780, 0.5424],
              [0.4344, 0.8551],
              [0.6729, 0.7372]],

             [[0.1669, 0.8596],
              [0.9490, 0.8378],
              [0.7889, 0.2192]]]))

    
# dim 1 has size 3, so split it as 2 + 1 = 3
a.split([2, 1], dim=1)
# Output:
    (tensor([[[0.5790, 0.6024],
              [0.4730, 0.0734]],

             [[0.7051, 0.1568],
              [0.5890, 0.1075]],

             [[0.7780, 0.5424],
              [0.4344, 0.8551]],

             [[0.1669, 0.8596],
              [0.9490, 0.8378]]]),
     tensor([[[0.2274, 0.7212]],

             [[0.7469, 0.0659]],

             [[0.6729, 0.7372]],

             [[0.7889, 0.2192]]]))    


# dim 2 has size 2, so split it as 1 + 1 = 2
a.split([1, 1], dim=2)
# Output:
(tensor([[[0.5790],
          [0.4730],
          [0.2274]],
 
         [[0.7051],
          [0.5890],
          [0.7469]],
 
         [[0.7780],
          [0.4344],
          [0.6729]],
 
         [[0.1669],
          [0.9490],
          [0.7889]]]),
 tensor([[[0.6024],
          [0.0734],
          [0.7212]],
 
         [[0.1568],
          [0.1075],
          [0.0659]],
 
         [[0.5424],
          [0.8551],
          [0.7372]],
 
         [[0.8596],
          [0.8378],
          [0.2192]]]))    

4. Splitting with chunk

  • chunk takes the number of (roughly equal) pieces to split into along the given dim.
  • Each piece gets ceil(size / chunks) elements, so when the dim is not evenly divisible you may get fewer pieces than requested.
  • Defaults to dim=0.
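The ceil-based sizing is the part that surprises people, so here is a small sketch of it before the worked examples:

```python
import torch

a = torch.rand(4, 3, 2)

# chunk(3) on a dim of size 4: each piece gets ceil(4 / 3) = 2 elements,
# so only 2 pieces come back even though 3 were requested.
parts = a.chunk(3, dim=0)
print(len(parts))                   # 2
print([p.shape[0] for p in parts])  # [2, 2]

# When the dim divides evenly, you get exactly the requested count.
print(len(a.chunk(2, dim=0)))  # 2
print(len(a.chunk(4, dim=0)))  # 4
```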
a
# Output:
    tensor([[[0.5790, 0.6024],
             [0.4730, 0.0734],
             [0.2274, 0.7212]],

            [[0.7051, 0.1568],
             [0.5890, 0.1075],
             [0.7469, 0.0659]],

            [[0.7780, 0.5424],
             [0.4344, 0.8551],
             [0.6729, 0.7372]],

            [[0.1669, 0.8596],
             [0.9490, 0.8378],
             [0.7889, 0.2192]]])
    
# dim defaults to 0
# split the data into 4 equal pieces along dim=0
a.chunk(4)
# Output:
    (tensor([[[0.5790, 0.6024],
              [0.4730, 0.0734],
              [0.2274, 0.7212]]]),
     tensor([[[0.7051, 0.1568],
              [0.5890, 0.1075],
              [0.7469, 0.0659]]]),
     tensor([[[0.7780, 0.5424],
              [0.4344, 0.8551],
              [0.6729, 0.7372]]]),
     tensor([[[0.1669, 0.8596],
              [0.9490, 0.8378],
              [0.7889, 0.2192]]]))    
    
# split the data into 3 pieces along dim=0
# dim 0 has size 4 and ceil(4 / 3) = 2, so each piece gets 2 elements and only
# 2 pieces are returned even though 3 were requested
a.chunk(3, dim=0)
# Output:
    (tensor([[[0.5790, 0.6024],
              [0.4730, 0.0734],
              [0.2274, 0.7212]],

             [[0.7051, 0.1568],
              [0.5890, 0.1075],
              [0.7469, 0.0659]]]),
     tensor([[[0.7780, 0.5424],
              [0.4344, 0.8551],
              [0.6729, 0.7372]],

             [[0.1669, 0.8596],
              [0.9490, 0.8378],
              [0.7889, 0.2192]]]))    
    
# split the data into 3 equal pieces along dim=1
a.chunk(3, dim=1)
# Output:
    (tensor([[[0.5790, 0.6024]],

             [[0.7051, 0.1568]],

             [[0.7780, 0.5424]],

             [[0.1669, 0.8596]]]),
     tensor([[[0.4730, 0.0734]],

             [[0.5890, 0.1075]],

             [[0.4344, 0.8551]],

             [[0.9490, 0.8378]]]),
     tensor([[[0.2274, 0.7212]],

             [[0.7469, 0.0659]],

             [[0.6729, 0.7372]],

             [[0.7889, 0.2192]]]))    
    
# split the data into 2 equal pieces along dim=2
a.chunk(2, dim=2)
# Output:
    (tensor([[[0.5790],
              [0.4730],
              [0.2274]],

             [[0.7051],
              [0.5890],
              [0.7469]],

             [[0.7780],
              [0.4344],
              [0.6729]],

             [[0.1669],
              [0.9490],
              [0.7889]]]),
     tensor([[[0.6024],
              [0.0734],
              [0.7212]],

             [[0.1568],
              [0.1075],
              [0.0659]],

             [[0.5424],
              [0.8551],
              [0.7372]],

             [[0.8596],
              [0.8378],
              [0.2192]]]))    

