PyTorch中的矩陣乘法


1. 二維矩陣乘法 [公式]

[公式] , 其中 [公式][公式], 輸出 [公式]的維度是[公式]。該函數[公式]一般只用來計算兩個二維矩陣的矩陣乘法,而且不支持broadcast操作。

 

2. 三維帶Batch矩陣乘法 [公式]

由於神經網絡訓練一般采用mini-batch,經常輸入的是三維帶batch矩陣,所以提供 [公式],其中 [公式][公式], 輸出 [公式]的維度是 [公式]。該函數的兩個輸入必須是三維矩陣且第一維相同(表示Batch維度),不支持broadcast操作。

3. "混合"矩陣乘法 [公式]

[公式] 支持broadcast操作,使用起來比較復雜,建議參考pytorch官方文檔

 

 

 特別 ,針對多維數據 [公式]乘法,我們可以認為該 [公式]乘法使用使用兩個參數的后兩個維度來計算,其他的維度都可以認為是batch維度。假設兩個輸入的維度分別是[公式][公式],那么我們可以認為 [公式] 乘法首先是進行后兩位矩陣乘法得到[公式] ,然后分析兩個參數的batch size分別是 [公式] 和 [公式] , 可以廣播成為 [公式], 因此最終輸出的維度是 [公式]

4. 矩陣逐元素(Element-wise)乘法 [公式]

[公式],其中 [公式] 乘數可以是標量也可以是任意維度的矩陣,只要滿足最終相乘是可以broadcast的即可,即該操作是支持broadcast操作的。

  • [公式] 是標量: 例如[公式]是維度任意的矩陣, [公式](一個標量), 那么輸出一個矩陣,其中每個值是 [公式]中原值乘以 [公式], 維度保持不變。

[公式] 是矩陣: 只要 [公式] 與 [公式] 的維度可以滿足broadcast條件,就可以進行逐元素乘法操作,例如:

1 import torch
2 A = torch.randn(2,3,4)
3 B = torch.randn(3, 4)
4 print (torch.mul(A,b).shape) # 輸出 torch.size([2,3,4)

5. 兩個乘法操作符@和[公式] 

簡單來說, @ 操作符可以執行矩陣乘法操作,類似 [公式] ; 而 [公式] 乘法操作可以執行逐元素矩陣乘法,使用方法類似 [公式]

 1 import torch
 2 
 3 x=torch.ones(3,2)
 4 print(x)
 5 
 6 y=torch.ones(3,2)+2
 7 print(y)
 8 
 9 z=torch.ones(2,1)
10 print(z)
11 
12 print(x*y@z)

 

 

 

 

參考:隨筆1: PyTorch中矩陣乘法總結 - 知乎 (zhihu.com)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM