一、torch.cat()函數
熟悉C字符串的同學們應該都用過strcat()函數,這個函數在C/C++程序中用於連接2個C字符串。在pytorch中,同樣有這樣的函數,那就是torch.cat()函數.
先上源碼定義:torch.cat(tensors,dim=0,out=None)
- 第一個參數tensors是你想要連接的若干個張量,按你所傳入的順序進行連接,注意每一個張量需要形狀相同,或者更准確的說,進行行連接的張量要求列數相同,進行列連接的張量要求行數相同
- 第二個參數dim表示維度,dim=0則表示按行連接,dim=1表示按列連接
a=torch.tensor([[1,2,3,4],[1,2,3,4]])
b=torch.tensor([[1,2,3,4,5],[1,2,3,4,5]])
print(torch.cat((a,b),1))
#輸出結果為:
tensor([[1, 2, 3, 4, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 1, 2, 3, 4, 5]])
二、torch.chunk()函數
torch.cat()函數是把各個tensor連接起來,這里的torch.chunk()的作用是把一個tensor均勻分割成若干個小tensor
源碼定義:torch.chunk(intput,chunks,dim=0)
- 第一個參數input是你想要分割的tensor
- 第二個參數chunks是你想均勻分割的份數,如果該tensor在你要進行分割的維度上的size不能被chunks整除,則最后一份會略小(也可能為空)
- 第三個參數表示分割維度,dim=0按行分割,dim=1表示按列分割
- 該函數返回由小tensor組成的list
c=torch.tensor([[1,4,7,9,11],[2,5,8,9,13]])
print(torch.chunk(c,3,1))
#輸出結果為:
(tensor([[1, 4],
[2, 5]]), tensor([[7, 9],
[8, 9]]), tensor([[11],
[13]]))
三、torch.split()函數
這個函數可以說是torch.chunk()函數的升級版本,它不僅可以按份數均勻分割,還可以按特定方案進行分割。
源碼定義:torch.split(tensor,split_size_or_sections,dim=0)
- 第一個參數是待分割張量
- 第二個參數有兩種形式。
一種是分割份數,這就和torch.chunk()一樣了。
第二種這是分割方案,這是一個list,待分割張量將會分割為len(list)份,每一份的大小取決於list中的元素 - 第三個參數為分割維度
section=[1,2,1,2,2]
d=torch.randn(8,4)
print(torch.split(d,section,dim=0))
#輸出結果為:
(tensor([[ 0.5388, -0.8537, 0.5539, 0.7793]]), tensor([[ 0.1270, 2.6241, -0.7594, 0.4644],
[ 0.8160, 0.5553, 0.1234, -1.1157]]), tensor([[-0.4433, -0.3093, -2.0134, -0.4277]]), tensor([[-0.4297, 0.2532, 0.2789, -0.3068],
[ 1.4208, -0.1202, 0.9256, -1.2127]]), tensor([[ 0.3542, -0.4656, 1.2683, 0.8753],
[-0.2786, -0.2180, 0.3991, 0.5658]]))
