Pytorch中的torch.cat()函數


cat是concatnate的意思:拼接,聯系在一起。


 

先說cat( )的普通用法

如果我們有兩個tensor是A和B,想把他們拼接在一起,需要如下操作:

C = torch.cat( (A,B),0 )  #按維數0拼接(豎着拼)

C = torch.cat( (A,B),1 )  #按維數1拼接(橫着拼)
復制代碼
>>> import torch
>>> A=torch.ones(2,3)    #2x3的張量(矩陣)                                     
>>> A
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> B=2*torch.ones(4,3)  #4x3的張量(矩陣)                                    
>>> B
tensor([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]])
>>> C=torch.cat((A,B),0)  #按維數0(行)拼接
>>> C
tensor([[ 1.,  1.,  1.],
         [ 1.,  1.,  1.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.]])
>>> C.size()
torch.Size([6, 3])
>>> D=2*torch.ones(2,4) #2x4的張量(矩陣)
>>> C=torch.cat((A,D),1)#按維數1(列)拼接
>>> C
tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
>>> C.size()
torch.Size([2, 7])
復制代碼

其次,cat還可以把list中的tensor拼接起來。

比如:

上面的代碼可以合成一行來寫:

 

 
 
標簽:  深度學習Pytorch


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM