pytorch中torch.cat(),torch.chunk(),torch.split()函數的使用方法


一、torch.cat()函數

熟悉C字符串的同學們應該都用過strcat()函數,這個函數在C/C++程序中用於連接2個C字符串。在pytorch中,同樣有這樣的函數,那就是torch.cat()函數.
先上源碼定義:torch.cat(tensors,dim=0,out=None)

  • 第一個參數tensors是你想要連接的若干個張量,按你所傳入的順序進行連接,注意每一個張量需要形狀相同,或者更准確的說,進行行連接的張量要求列數相同,進行列連接的張量要求行數相同
  • 第二個參數dim表示維度,dim=0則表示按行連接,dim=1表示按列連接
a=torch.tensor([[1,2,3,4],[1,2,3,4]])
b=torch.tensor([[1,2,3,4,5],[1,2,3,4,5]])
print(torch.cat((a,b),1))
#輸出結果為:
tensor([[1, 2, 3, 4, 1, 2, 3, 4, 5],
        [1, 2, 3, 4, 1, 2, 3, 4, 5]])

二、torch.chunk()函數

torch.cat()函數是把各個tensor連接起來,這里的torch.chunk()的作用是把一個tensor均勻分割成若干個小tensor
源碼定義:torch.chunk(intput,chunks,dim=0)

  • 第一個參數input是你想要分割的tensor
  • 第二個參數chunks是你想均勻分割的份數,如果該tensor在你要進行分割的維度上的size不能被chunks整除,則最后一份會略小(也可能為空)
  • 第三個參數表示分割維度,dim=0按行分割,dim=1表示按列分割
  • 該函數返回由小tensor組成的list
c=torch.tensor([[1,4,7,9,11],[2,5,8,9,13]])
print(torch.chunk(c,3,1))
#輸出結果為:
(tensor([[1, 4],
        [2, 5]]), tensor([[7, 9],
        [8, 9]]), tensor([[11],
        [13]]))

三、torch.split()函數

這個函數可以說是torch.chunk()函數的升級版本,它不僅可以按份數均勻分割,還可以按特定方案進行分割。
源碼定義:torch.split(tensor,split_size_or_sections,dim=0)

  • 第一個參數是待分割張量
  • 第二個參數有兩種形式。
    一種是分割份數,這就和torch.chunk()一樣了。
    第二種這是分割方案,這是一個list,待分割張量將會分割為len(list)份,每一份的大小取決於list中的元素
  • 第三個參數為分割維度
section=[1,2,1,2,2]
d=torch.randn(8,4)
print(torch.split(d,section,dim=0))
#輸出結果為:
(tensor([[ 0.5388, -0.8537,  0.5539,  0.7793]]), tensor([[ 0.1270,  2.6241, -0.7594,  0.4644],
        [ 0.8160,  0.5553,  0.1234, -1.1157]]), tensor([[-0.4433, -0.3093, -2.0134, -0.4277]]), tensor([[-0.4297,  0.2532,  0.2789, -0.3068],
        [ 1.4208, -0.1202,  0.9256, -1.2127]]), tensor([[ 0.3542, -0.4656,  1.2683,  0.8753],
        [-0.2786, -0.2180,  0.3991,  0.5658]]))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM