1. 二維矩陣乘法 torch.mm()
torch.mm(mat1, mat2, out=None)
,其中mat1
(\(n\times m\)),mat2
(\(m\times d\)),輸出out
的維度是(\(n\times d\))。
該函數一般只用來計算兩個二維矩陣的矩陣乘法,並且不支持broadcast操作。
2. 三維帶batch的矩陣乘法 torch.bmm()
由於神經網絡訓練一般采用mini-batch,經常輸入的時三維帶batch的矩陣,所以提供torch.bmm(bmat1, bmat2, out=None)
,其中bmat1
(\(b\times n \times m\)),bmat2
(\(b\times m \times d\)),輸出out
的維度是(\(b \times n \times d\))。
該函數的兩個輸入必須是三維矩陣且第一維相同(表示Batch維度),不支持broadcast操作。
3. 多維矩陣乘法 torch.matmul()
torch.matmul(input, other, out=None)
支持broadcast操作,使用起來比較復雜。
針對多維數據 matmul()
乘法,我們可以認為該matmul()
乘法使用使用兩個參數的后兩個維度來計算,其他的維度都可以認為是batch維度。假設兩個輸入的維度分別是input
(\(1000 \times 500 \times 99 \times 11\)), other
(\(500 \times 11 \times 99\))那么我們可以認為torch.matmul(input, other, out=None)
乘法首先是進行后兩位矩陣乘法得到\((99 \times 11) \times (11 \times 99)\Rightarrow(99 \times 99)\) ,然后分析兩個參數的batch size分別是 \(( 1000 \times 500)\) 和 \(500\) , 可以廣播成為 \((1000 \times 500)\), 因此最終輸出的維度是(\(1000 \times 500 \times 99 \times 99\))。
4. 矩陣逐元素(Element-wise)乘法 torch.mul()
torch.mul(mat1, other, out=None)
,其中other
乘數可以是標量,也可以是任意維度的矩陣,只要滿足最終相乘是可以broadcast的即可
5. 兩個運算符 @ 和 *
@
:矩陣乘法,自動執行適合的矩陣乘法函數*
:element-wise乘法