tensor的拼接與拆分
cat函數
例子:成績單的合並
【班級1~4 學生 得分】
【班級5~9 學生 得分】
在0維進行合並,非cat的維度必須一致
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
c = torch.cat([a,b],dim=0)
c.shape()
#[9,32,8]
stack函數
會新添加一個維度,要保證兩個stack的tensor的維度一摸一樣
,在理解方面是添加了新的概念在里面。
例子:
一班:【32個學生 每個學生8門課程】
二班:【32個學生 每個學生8門課程】
stack之后變為【兩個班級 每個班級32個學生 每個學生有8門課程】
a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape
#[2 32 8]
split函數
split函數按照長度來拆分
例子1:
參數說明:【1,1】表示前面的長度為1,后面的長度也是1
a = torch.rand(2,32,8)
b,c = torch.split([1,1],dim=0)
b.shape
#[1,32,8]
c.shape()
#[1,32,8]
例子2:
參數說明:【2,1】表示前面的長度為2,后面的長度為1
(不規則分割的參數含義)
a = torch.rand(3,32,8)
b,c = torch.split([2,1],dim=0)
b.shape
#[2,32,8]
c.shape()
#[1,32,8]
chunk函數
根據數量來進行分割(盡量實現整除,后面除不盡的留給最后)
例子:
a = torch.rand(6,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
例子2:
a = torch.rand(5,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([1, 32, 8])
例子3:
a = torch.rand(5,32,8)
b,c= torch.chunk(a,2,dim=0)
print(b.shape)
print(c.shape)
#torch.Size([3, 32, 8])
#torch.Size([2, 32, 8])