tensor的拼接與拆分


tensor的拼接與拆分

cat函數

例子:成績單的合並

【班級1~4 學生 得分】

【班級5~9 學生 得分】

在0維進行合並,非cat的維度必須一致

a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
c = torch.cat([a,b],dim=0)
c.shape()
#[9,32,8]

stack函數

會新添加一個維度,要保證兩個stack的tensor的維度一摸一樣,在理解方面是添加了新的概念在里面。

例子:

一班:【32個學生 每個學生8門課程】

二班:【32個學生 每個學生8門課程】

stack之后變為【兩個班級 每個班級32個學生 每個學生有8門課程】

a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape
#[2 32 8]

split函數

split函數按照長度來拆分

例子1:

參數說明:【1,1】表示前面的長度為1,后面的長度也是1

a = torch.rand(2,32,8)
b,c = torch.split([1,1],dim=0)
b.shape
#[1,32,8]
c.shape()
#[1,32,8]

例子2:

參數說明:【2,1】表示前面的長度為2,后面的長度為1(不規則分割的參數含義)

a = torch.rand(3,32,8)
b,c = torch.split([2,1],dim=0)
b.shape
#[2,32,8]
c.shape()
#[1,32,8]

chunk函數

根據數量來進行分割(盡量實現整除,后面除不盡的留給最后)

例子:

a = torch.rand(6,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])

例子2:

a = torch.rand(5,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([1, 32, 8])

例子3:

a = torch.rand(5,32,8)
b,c= torch.chunk(a,2,dim=0)
print(b.shape)
print(c.shape)

#torch.Size([3, 32, 8])
#torch.Size([2, 32, 8])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM