『PyTorch』矩陣乘法總結


1. 二維矩陣乘法 torch.mm()

torch.mm(mat1, mat2, out=None),其中mat1(\(n\times m\)),mat2(\(m\times d\)),輸出out的維度是(\(n\times d\))。

該函數一般只用來計算兩個二維矩陣的矩陣乘法,並且不支持broadcast操作。

2. 三維帶batch的矩陣乘法 torch.bmm()

由於神經網絡訓練一般采用mini-batch,經常輸入的時三維帶batch的矩陣,所以提供torch.bmm(bmat1, bmat2, out=None),其中bmat1(\(b\times n \times m\)),bmat2(\(b\times m \times d\)),輸出out的維度是(\(b \times n \times d\))。

該函數的兩個輸入必須是三維矩陣且第一維相同(表示Batch維度),不支持broadcast操作。

3. 多維矩陣乘法 torch.matmul()

torch.matmul(input, other, out=None)支持broadcast操作,使用起來比較復雜。

針對多維數據 matmul()乘法,我們可以認為該matmul()乘法使用使用兩個參數的后兩個維度來計算,其他的維度都可以認為是batch維度。假設兩個輸入的維度分別是input(\(1000 \times 500 \times 99 \times 11\)), other(\(500 \times 11 \times 99\))那么我們可以認為torch.matmul(input, other, out=None)乘法首先是進行后兩位矩陣乘法得到\((99 \times 11) \times (11 \times 99)\Rightarrow(99 \times 99)\) ,然后分析兩個參數的batch size分別是 \(( 1000 \times 500)\)\(500\) , 可以廣播成為 \((1000 \times 500)\), 因此最終輸出的維度是(\(1000 \times 500 \times 99 \times 99\))。

4. 矩陣逐元素(Element-wise)乘法 torch.mul()

torch.mul(mat1, other, out=None),其中other乘數可以是標量,也可以是任意維度的矩陣,只要滿足最終相乘是可以broadcast的即可

5. 兩個運算符 @ 和 *

  • @:矩陣乘法,自動執行適合的矩陣乘法函數
  • *element-wise乘法


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM