Pytorch張量高階操作


1.Broadcasting

Broadcasting能夠實現Tensor自動維度增加(unsqueeze)與維度擴展(expand),以使兩個Tensor的shape一致,從而完成某些操作,主要按照如下步驟進行:

  • 從最后面的維度開始匹配(一般后面理解為小維度);
  • 前面插入若干維度,進行unsqueeze操作;
  • 將維度的size從1通過expand變到和某個Tensor相同的維度。

舉例:

Feature maps:[4, 32, 14, 14]

Bias:[32, 1, 1](Tip:后面的兩個1是手動unsqueeze插入的維度)->[1, 32, 1, 1]->[4, 32, 14, 14]

32

匹配規則(從最后面的維度開始匹配):

  • if current dim=1,expand to same
  • if either has no dim,insert one dim and expand to same
  • otherwise,NOT broadcasting-able

A的維度[4, 32, 8],B的維度[1],[1]->[1, 1, 1]->[4, 32, 8],對應情況1

A的維度[4, 32, 8],B的維度[8],[1]->[1, 1, 8]->[4, 32, 8],對應情況2

A的維度[4, 32, 8],B的維度[4],對應情況3,不能broadcasting

2.拼接與拆分

cat拼接操作

  • 功能:通過dim指定維度,在當前指定維度上直接拼接
  • 默認是dim=0
  • 指定的dim上,維度可以不相同,其他dim上維度必須相同,不然會報錯
1 a1=torch.rand(4,3,32,32)
2 a2=torch.rand(5,3,32,32)
3 print(torch.cat([a1,a2],dim=0).shape)    #torch.Size([9, 3, 32, 32])
4 
5 a3=torch.rand(4,1,32,32)
6 print(torch.cat([a1,a3],dim=1).shape)    #torch.Size([4, 4, 32, 32])
7 
8 a4=torch.rand(4,3,16,32)
9 print(torch.cat([a1,a4],dim=2).shape)    #torch.Size([4, 3, 48, 32])

stack拼接操作

  • 與cat不同的是,stack是在拼接的同時,在指定dim處插入維度后拼接(create new dim
  • stack需要保證兩個Tensor的shape是一致的,這就像是有兩類東西,它們的其它屬性都是一樣的(比如男的一張表,女的一張表)。使用stack時候要指定一個維度位置,在那個位置前會插入一個新的維度,因為是兩類東西合並過來所以這個新的維度size是2,通過指定這個維度是0或者1來選擇性別是男還是女。

  • 默認dim=0
1 a1=torch.rand(4,3,32,32)
2 a2=torch.rand(4,3,32,32)
3 print(torch.stack([a1,a2],dim=1).shape)  #torch.Size([4, 2, 3, 32, 32])  
左邊起第二個維度取0時,取上半部分即a1,左邊起第二個維度取1時,取下半部分即a2
4 print(torch.stack([a1,a2],dim=2).shape) #torch.Size([4, 3, 2, 32, 32])

split分割操作

  • 指定拆分dim
  • 按長度拆分,給定拆分后的數據大小
1 c=torch.rand(3,32,8)
2 
3 aa,bb=c.split([1,2],dim=0)     
4 print(aa.shape,bb.shape)            #torch.Size([1, 32, 8]) torch.Size([2, 32, 8])
5 
6 aa,bb,cc=c.split([1,1,1],dim=0)     #或者寫成aa,bb,cc=c.split(1,dim=0) 
7 print(aa.shape,bb.shape,cc.shape)   #torch.Size([1, 32, 8]) torch.Size([1, 32, 8]) torch.Size([1, 32, 8])

chunk分割操作

  • chunk是在指定dim下按個數拆分,給定平均拆分的個數
  • 如果給定個數不能平均拆分當前維度,則會取比給定個數小的,能平均拆分數據的,最大的個數
  • dim默認是0
1 c=torch.rand(3,32,8)
2 d=torch.rand(2,32,8)
3 aa,bb=d.chunk(2,dim=0)
4 print(aa.shape,bb.shape)            #torch.Size([1, 32, 8]) torch.Size([1, 32, 8])
5 
6 aa,bb=c.chunk(2,dim=0)
7 print(aa.shape,bb.shape)            #torch.Size([2, 32, 8]) torch.Size([1, 32, 8])

3.基本運算

加法(a+b、torch.add(a,b))

減法(a-b、torch.sub(a,b))

乘法(*、torch.mul(a,b))對應元素相乘

除法(/、torch.div(a,b))對應元素相除,//整除

1 a = torch.rand(3, 4)
2 b = torch.rand(4)
3  
4 c1 = a + b
5 c2 = torch.add(a, b)
6 print(c1.shape, c2.shape)                #torch.Size([3, 4]) torch.Size([3, 4])
7 print(torch.all(torch.eq(c1, c2)))       #tensor(True)

矩陣乘法

torch.mm(only for 2d,不推薦使用)

torch.matmul(推薦)

@

1 a=torch.rand(2,1)
2 b=torch.rand(1,2)
3 print(torch.mm(a,b).shape)          #torch.Size([2, 2])
4 print(torch.matmul(a,b).shape)      #torch.Size([2, 2])
5 print((a@b).shape)                  #torch.Size([2, 2])

應用於矩陣降維

1 x=torch.rand(4,784)
2 w=torch.rand(512,784)             #channel-out對應512,channel-in對應784
3 print((x@w.t()).shape)            #torch.Size([4, 512]) Tip:.t()只適用於二維

多維矩陣相乘

對於高維的Tensor(dim>2),定義其矩陣乘法僅在最后的兩個維度上,要求前面的維度必須保持一致,就像矩陣的索引一樣並且運算操作符只有torch.matmul()。

1 a=torch.rand(4,3,28,64)
2 b=torch.rand(4,3,64,32)
3 print(torch.matmul(a,b).shape)    #torch.Size([4, 3, 28, 32])
4        
5 c=torch.rand(4, 1, 64, 32)
6 print(torch.matmul(a,c).shape)    #torch.Size([4, 3, 28, 32])
7 
8 d=torch.rand(4,64,32)
9 print(torch.matmul(a,d).shape)    #報錯

Tip:這種情形下的矩陣相乘,"矩陣索引維度"如果符合Broadcasting機制,也會自動做廣播,然后相乘。

次方pow、**操作

1 a = torch.full([2, 2], 3)  
2 b = a.pow(2)                 #也可以a**2
3 print(b)
4 #tensor([[9., 9.],
5 #        [9., 9.]])

開方sqrt、**操作

1 #接上面
2 c = b.sqrt()   #也可以a**(0.5)
3 print(c)
4 #tensor([[3., 3.],
5 #        [3., 3.]])
6 d = b.rsqrt()  #平方根的倒數
7 print(d)
8 #tensor([[0.3333, 0.3333],
9 #        [0.3333, 0.3333]])

指數exp與對數log運算

log是以自然對數為底數的,以2為底的用log2,以10為底的用log10。

1 a = torch.exp(torch.ones(2, 2))  #得到2*2的全是e的Tensor
2 print(a)
3 #tensor([[2.7183, 2.7183],
4 #        [2.7183, 2.7183]])
5 print(torch.log(a))              #取自然對數
6 #tensor([[1., 1.],
7 #        [1., 1.]])

近似值運算

1 a = torch.tensor(3.14)
2 print(a.floor(), a.ceil(), a.trunc(), a.frac())  #取下,取上,取整數,取小數
3 #tensor(3.) tensor(4.) tensor(3.) tensor(0.1400)
4 b = torch.tensor(3.49)
5 c = torch.tensor(3.5)
6 print(b.round(), c.round())                      #四舍五入tensor(3.) tensor(4.)

裁剪運算clamp

對Tensor中的元素進行范圍過濾,不符合條件的可以把它變換到范圍內部(邊界)上,常用於梯度裁剪(gradient clipping),即在發生梯度離散或者梯度爆炸時對梯度的處理,實際使用時可以查看梯度的(L2范數)模來看看需不需要做處理:w.grad.norm(2)

 1 grad = torch.rand(2, 3) * 15      #0~15隨機生成
 2 print(grad.max(), grad.min(), grad.median())  #tensor(12.9533) tensor(1.5625) tensor(11.1101)
 3  
 4 print(grad)
 5 #tensor([[12.7630, 12.9533,  7.6125],
 6 #        [11.1101, 12.4215,  1.5625]])
 7 print(grad.clamp(10))             #最小是10,小於10的都變成10
 8 #tensor([[12.7630, 12.9533, 10.0000],
 9 #        [11.1101, 12.4215, 10.0000]])
10 print(grad.clamp(3, 10))          #最小是3,小於3的都變成3;最大是10,大於10的都變成10
11 #tensor([[10.0000, 10.0000,  7.6125],
12 #        [10.0000, 10.0000,  3.0000]])

4.統計屬性

范數norm

Vector norm 和matrix norm區別

 1 a=torch.full([8],1)
 2 b=a.view(2,4)
 3 c=a.view(2,2,2)
 4 print(b)
 5 #tensor([[1., 1., 1., 1.],
 6 #        [1., 1., 1., 1.]])
 7 print(c)
 8 #tensor([[[1., 1.],
 9 #         [1., 1.]],
10 #        [[1., 1.],
11 #         [1., 1.]]])
12 
13 #求L1范數(所有元素絕對值求和)
14 print(a.norm(1),b.norm(1),c.norm(1))            #tensor(8.) tensor(8.) tensor(8.)
15 #求L2范數(所有元素的平方和再開根)
16 print(a.norm(2),b.norm(2),c.norm(2))            #tensor(2.8284) tensor(2.8284) tensor(2.8284)
17 
18 # 在b的1號維度上求L1范數
19 print(b.norm(1, dim=1))            #tensor([4., 4.])
20 # 在b的1號維度上求L2范數
21 print(b.norm(2, dim=1))            #tensor([2., 2.])
22  
23 # 在c的0號維度上求L1范數
24 print(c.norm(1, dim=0))
25 #tensor([[2., 2.],
26 #        [2., 2.]])
27 # 在c的0號維度上求L2范數
28 print(c.norm(2, dim=0))
29 #tensor([[1.4142, 1.4142],
30 #        [1.4142, 1.4142]])

均值mean、累加sum、最小min、最大max、累積prod

最大值最小值索引argmax、argmin

 1 b = torch.arange(8).reshape(2, 4).float()
 2 print(b)
 3 #均值,累加,最小,最大,累積
 4 print(b.mean(), b.sum(), b.min(), b.max(), b.prod())       #tensor(3.5000) tensor(28.) tensor(0.) tensor(7.) tensor(0.)  
 5 
 6 #不指定維度,輸出打平后的最小最大值索引
 7 print(b.argmax(), b.argmin())                              #tensor(7) tensor(0)
 8 #指定維度1,輸出每一行最大值所在的索引
 9 print(b.argmax(dim=1))                                     #tensor([3, 3])
10 #指定維度0,輸出每一列最大值所在的索引
11 print(b.argmax(dim=0))                                     #tensor([1, 1, 1, 1])

Tip:上面的argmax、argmin操作默認會將Tensor打平后取最大值索引和最小值索引,如果不希望Tenosr打平,而是求給定維度上的索引,需要指定在哪一個維度上求最大值或最小值索引。

dim、keepdim

比方說shape=[4,10],dim=1時,保留第0個維度,即max輸出會有4個值。

 1 a=torch.rand(4,10)
 2 print(a.max(dim=1))                                      #返回結果和索引
 3 # torch.return_types.max(
 4 # values=tensor([0.9770, 0.8467, 0.9866, 0.9064]),
 5 # indices=tensor([4, 2, 2, 4]))
 6 print(a.argmax(dim=1))                                   #tensor([4, 2, 2, 4])
 7 
 8 print(a.max(dim=1,keepdim=True))
 9 # torch.return_types.max(
10 # values=tensor([[0.9770],
11 #         [0.8467],
12 #         [0.9866],
13 #         [0.9064]]),
14 # indices=tensor([[4],
15 #         [2],
16 #         [2],
17 #         [4]]))
18 print(a.argmax(dim=1,keepdim=True))
19 # tensor([[4],
20 #         [2],
21 #         [2],
22 #         [4]])

Tip:使用keepdim=True可以保持應有的dim,即僅僅是將求最值的那個dim的size變成了1,返回的結果是符合原Tensor語義的。

取前k大topk(largest=True)/前k小(largest=False)的概率值及其索引

第k小(kthvalue)的概率值及其索引

 1 # 2個樣本,分為10個類別的置信度
 2 d = torch.randn(2, 10)  
 3 # 最大概率的3個類別
 4 print(d.topk(3, dim=1))  
 5 # torch.return_types.topk(
 6 # values=tensor([[1.6851, 1.5693, 1.5647],
 7 #                [0.8492, 0.4311, 0.3302]]),
 8 # indices=tensor([[9, 1, 4],
 9 #                 [6, 2, 4]]))
10 
11 # 最小概率的3個類別
12 print(d.topk(3, dim=1, largest=False))  
13 # torch.return_types.topk(
14 # values=tensor([[-1.2583, -0.7617, -0.4518],
15 #         [-1.5011, -0.9987, -0.9042]]),
16 # indices=tensor([[6, 7, 2],
17 #         [3, 1, 9]]))
18 
19 # 求第8小概率的類別(一共10個那就是第3大,正好對應上面最大概率的3個類別的第3列)
20 print(d.kthvalue(8, dim=1))  
21 # torch.return_types.kthvalue(
22 # values=tensor([1.5647, 0.3302]),
23 # indices=tensor([4, 4]))

比較操作

>,>=,<,<=,!=,==

torch.eq(a,b)、torch.equal(a,b)

 1 a=torch.randn(2,3)
 2 b=torch.randn(2,3)
 3 print(a>0)
 4 print(torch.gt(a,0))
 5 # tensor([[False,  True,  True],
 6 #         [True, False, False]])
 7 
 8 
 9 print(torch.equal(a,a))        #True
10 print(torch.eq(a,a))
11 # tensor([[True, True, True],
12 #         [True, True, True]])

5.高階操作

where

使用C=torch.where(condition,A,B)其中A,B,C,condition是shape相同的Tensor,C中的某些元素來自A,某些元素來自B,這由condition中對應位置的元素是1還是0來決定。如果condition對應位置元素是1,則C中的該位置的元素來自A中的該位置的元素,如果condition對應位置元素是0,則C中的該位置的元素來自B中的該位置的元素。

 1 cond=torch.tensor([[0.6,0.1],[0.8,0.7]])
 2 
 3 a=torch.tensor([[1,2],[3,4]])
 4 b=torch.tensor([[4,5],[6,7]])
 5 print(cond>0.5)
 6 # tensor([[ True, False],
 7 #         [ True,  True]])
 8 print(torch.where(cond>0.5,a,b))
 9 # tensor([[1, 5],
10 #         [3, 4]])

gather

torch.gather(input, dim, index, out=None)對元素實現一個查表映射的操作:

 1 prob=torch.randn(4,10)
 2 idx=prob.topk(dim=1,k=3)
 3 print(idx)
 4 # torch.return_types.topk(
 5 # values=tensor([[ 1.6009,  0.7975,  0.6671],
 6 #                [ 1.0937,  0.9888,  0.7749],
 7 #                [ 1.1727,  0.6124, -0.3543],
 8 #                [ 1.1406,  0.8465,  0.6256]]),
 9 # indices=tensor([[8, 5, 4],
10 #                 [6, 9, 8],
11 #                 [1, 3, 8],
12 #                 [6, 1, 3]]))
13 idx=idx[1]
14 
15 label=torch.arange(10)+100
16 print(torch.gather(label.expand(4,10), dim=1, index=idx.long()))
17 # tensor([[108, 105, 104],
18 #         [106, 109, 108],
19 #         [101, 103, 108],
20 #         [106, 101, 103]])

label=[[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],

      [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],

           [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],

           [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]]

index=[[8, 5, 4],

       [6, 9, 8],

            [1, 3, 8],

            [6, 1, 3]]

gather的含義就是利用index來索引input特定位置的數值。

補充scatter_

scatter_(dim, index, src)將src中數據根據index中的索引按照dim的方向填進input中

細節再補充。。。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM