1. 二維矩陣乘法 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD10b3JjaC5tbSUyOCUyOQ==.png)
, 其中
,
, 輸出
的維度是
。該函數
一般只用來計算兩個二維矩陣的矩陣乘法,而且不支持broadcast操作。
2. 三維帶Batch矩陣乘法 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD10b3JjaC5ibW0lMjglMjk=.png)
由於神經網絡訓練一般采用mini-batch,經常輸入的是三維帶batch矩陣,所以提供 ,其中
,
, 輸出
的維度是
。該函數的兩個輸入必須是三維矩陣且第一維相同(表示Batch維度),不支持broadcast操作。
3. "混合"矩陣乘法 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD10b3JjaC5tYXRtdWwlMjglMjk=.png)
支持broadcast操作,使用起來比較復雜,建議參考pytorch官方文檔。
特別 ,針對多維數據 乘法,我們可以認為該
乘法使用使用兩個參數的后兩個維度來計算,其他的維度都可以認為是batch維度。假設兩個輸入的維度分別是
,
,那么我們可以認為
乘法首先是進行后兩位矩陣乘法得到
,然后分析兩個參數的batch size分別是
和
, 可以廣播成為
, 因此最終輸出的維度是
。
4. 矩陣逐元素(Element-wise)乘法 ![[公式]](/image/aHR0cHM6Ly93d3cuemhpaHUuY29tL2VxdWF0aW9uP3RleD10b3JjaC5tdWwlMjglMjk=.png)
,其中
乘數可以是標量也可以是任意維度的矩陣,只要滿足最終相乘是可以broadcast的即可,即該操作是支持broadcast操作的。
是標量: 例如
是維度任意的矩陣,
(一個標量), 那么輸出一個矩陣,其中每個值是
中原值乘以
, 維度保持不變。
是矩陣: 只要
與
的維度可以滿足broadcast條件,就可以進行逐元素乘法操作,例如:
1 import torch 2 A = torch.randn(2,3,4) 3 B = torch.randn(3, 4) 4 print (torch.mul(A,b).shape) # 輸出 torch.size([2,3,4)
5. 兩個乘法操作符@和
簡單來說, @ 操作符可以執行矩陣乘法操作,類似 ; 而
乘法操作可以執行逐元素矩陣乘法,使用方法類似
。
1 import torch 2 3 x=torch.ones(3,2) 4 print(x) 5 6 y=torch.ones(3,2)+2 7 print(y) 8 9 z=torch.ones(2,1) 10 print(z) 11 12 print(x*y@z)