張量的拼接


張量的拼接有兩種主要的基本策略:

不增加張量的維度,只增加某個維度方向的長度:cat()
增加張量的維度,不增加單個維度方向的長度:stack()
第2章 增加張量長度的拼接:cat()
2.1 基本原理

 

 

 

2.2 函數說明
功能:在不改變張量維度的情況下,通過增加張量在某個維度方向的長度,把兩個張量拼接起來。

原型:cat(input, dim)

輸入參數:

input: 輸入張量

dim:拼接的方向

2.3 代碼示例
(1)按照dim =0 的方向拼接

# 張量的拼接:階數不變,增加dim方向的長度
a = torch.Tensor([[1,1,1,1], [2,2,2,2],[3,3,3,3]])
b = torch.Tensor([[4,4,4,4], [5,5,5,5],[5,5,5,5]])
print("源張量")
print(a)
print(a.shape)
print(b)
print(b.shape)

print("\n按照dim=0方向拼接")
c = torch.cat((a,b),dim=0)
print(c)
print(c.shape)
輸出:

源張量
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
torch.Size([3, 4])
tensor([[4., 4., 4., 4.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]])
torch.Size([3, 4])

按照dim=0方向拼接


tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.],
[4., 4., 4., 4.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]])
torch.Size([6, 4])
(2)按照dim=1的方向進行拼接

# 張量的拼接:增加階數,
a = torch.Tensor([[1,1,1,1], [2,2,2,2],[3,3,3,3]])
b = torch.Tensor([[4,4,4,4], [5,5,5,5],[5,5,5,5]])
print("源張量")
print(a)
print(a.shape)
print(b)
print(b.shape)

print("\n按照dim=0方向拼接:擴展階數")
c = torch.stack((a,b),dim=0)
print(c)
print(c.shape)
輸出:

源張量
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
torch.Size([3, 4])
tensor([[4., 4., 4., 4.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]])
torch.Size([3, 4])

按照dim=1方向拼接
tensor([[1., 1., 1., 1., 4., 4., 4., 4.],
[2., 2., 2., 2., 5., 5., 5., 5.],
[3., 3., 3., 3., 5., 5., 5., 5.]])
torch.Size([3, 8])
第3章 增加張量維度的拼接:stack()
3.1 基本原理
stack堆疊,可以增加一個維度,因此案例中dim=0,1,2 三種情形,在三種方向進行堆疊。

(1)按照dim = 0的方向堆疊

 

 

 

 

(2)按照dim = 1的方向堆疊

 

 

 

(3)按照dim = 2的方向堆疊

 

 

 

2.2 函數說明
功能:通過增加張量維度,把兩個張量堆疊起來,堆疊后,維度增加1.

原型:stack(input, dim)

輸入參數:

input: 輸入張量

dim:拼接的方向,這里的dim是指拼接后張量的dim,而不是原張量的dim

2.3 代碼示例
(1)dim=0的方向

# 張量的拼接:增加階數,
a = torch.Tensor([[1,1,1,1], [2,2,2,2],[3,3,3,3]])
b = torch.Tensor([[4,4,4,4], [5,5,5,5],[5,5,5,5]])
print("源張量")
print(a)
print(a.shape)
print(b)
print(b.shape)

print("\n按照dim=0方向拼接:擴展階數")
c = torch.stack((a,b),dim=0)
print(c)
print(c.shape)
輸出:

源張量
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
torch.Size([3, 4])
tensor([[4., 4., 4., 4.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]])
torch.Size([3, 4])

按照dim=0方向拼接:擴展階數
tensor([[[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]],

[[4., 4., 4., 4.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]]])
torch.Size([2, 3, 4])
(2)dim=1的方向

a = torch.Tensor([[1,1,1,1], [2,2,2,2],[3,3,3,3]])
b = torch.Tensor([[4,4,4,4], [5,5,5,5],[5,5,5,5]])
print("源張量")
print(a)
print(a.shape)
print(b)
print(b.shape)

print("\n按照dim=1方向拼接:擴展階數")
c = torch.stack((a,b),dim=1)
print(c)
print(c.shape)
輸出:

源張量
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
torch.Size([3, 4])
tensor([[4., 4., 4., 4.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]])
torch.Size([3, 4])

按照dim=1方向拼接:擴展階數
tensor([[[1., 1., 1., 1.],
[4., 4., 4., 4.]],

[[2., 2., 2., 2.],
[5., 5., 5., 5.]],

[[3., 3., 3., 3.],
[5., 5., 5., 5.]]])
torch.Size([3, 2, 4])
(3)dim=2的方向

a = torch.Tensor([[1,1,1,1], [2,2,2,2],[3,3,3,3]])
b = torch.Tensor([[4,4,4,4], [5,5,5,5],[5,5,5,5]])
print("源張量")
print(a)
print(a.shape)
print(b)
print(b.shape)

print("\n按照dim=2方向拼接:擴展階數")
c = torch.stack((a,b),dim=2)
print(c)
print(c.shape)
輸出:

源張量
tensor([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]])
torch.Size([3, 4])
tensor([[4., 4., 4., 4.],
[5., 5., 5., 5.],
[5., 5., 5., 5.]])
torch.Size([3, 4])

按照dim=2方向拼接:擴展階數
tensor([[[1., 4.],
[1., 4.],
[1., 4.],
[1., 4.]],

[[2., 5.],
[2., 5.],
[2., 5.],
[2., 5.]],

[[3., 5.],
[3., 5.],
[3., 5.],
[3., 5.]]])
torch.Size([3, 4, 2])


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM