torch.bmm(), torch.mul(), torch.matmul()


torch.mul(a, b)

點乘:對應位相乘,維度必須相等
返回維度與 a, b 相同

torch.mm(a, b)

矩陣相乘
如:
a: [1, 2]
b: [2, 3]
output: [1, 3]

torch.bmm(a, b)

a, b必須是3D維度的,對兩個向量維度有要求。
a: [p, m, n]
b: [p, n, q]
output: [p, m, q]
可看作把第 0 維 p 提出,[m, n] 與 [n, q] 相乘(維度層面)

import torch
x = torch.rand(2,3,6)
y = torch.rand(2,6,7)
print(torch.bmm(x,y).size())
 
output:
torch.Size([2, 3, 7])
 
###############################
y = torch.rand(2,5,7) ##維度不匹配,報錯
print(torch.bmm(x,y).size())
 
output:
Expected tensor to have size 6 at dimension 1, but got size 5 for argument #2 'batch2' (while checking arguments for bmm)
torch.matmul()

torch.matmul(input, other, out=None) → Tensor

torch.matmul

a, b 均為1D(向量)

返回兩個向量的點積

import torch
x = torch.rand(2)
y = torch.rand(2)
print(torch.matmul(x,y),torch.matmul(x,y).size())
output:
tensor(0.1353) torch.Size([])

a, b 都是2D(矩陣)

按照(矩陣相乘)規則返回2D

x = torch.rand(2,4)
y = torch.rand(4,3) 
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
 
output:
tensor([[0.9128, 0.8425, 0.7269],
        [1.4441, 1.5334, 1.3273]]) 
 torch.Size([2, 3])

維度也要對應才可以乘

a = torch.ones(4,3)
b = torch.ones(4,3)
y = torch.matmul(b,a)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x3 and 4x3)

a為1維,b為2維

  • 先將1D的維度擴充到2D(1D的維數前面+1)
  • 得到結果后再將此維度去掉,得到的與input的維度相同。
    即使作擴充(廣播)處理,a 的維度也要和 b 維度做對應關系。
import torch
x = torch.rand(4) #1D  
y = torch.rand(4,3) #2D
print(x.size())
print(y.size())
# x: 擴充x =>(,4) * y:[4, 3] =>[ , 3] => [3]
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
 
output:
torch.Size([4])
torch.Size([4, 3])
tensor([0.9600, 0.5736, 1.0430]) 
torch.Size([3])

a為2維,b為1維

對應行相乘求和

import torch
x = torch.rand(4,3) # 2D
y = torch.rand(3) # 1D
print(torch.matmul(x,y).size()) #2D*1D

output:
torch.Size([4])

a, b 均為 3 維

a: [batch_size, 1, seq_len]
b: [batch_size, seq_len, bert_dim]
torch.matmul(a,b) => [batch_size, 1, bert_dim]
類似 torch.bmm:
可看作把第 0 維 batch_size 提出,[m, n] 與 [n, q] 相乘(維度層面)

作者:top_小醬油
鏈接:https://www.jianshu.com/p/e277f7fc67b3
來源:簡書
著作權歸作者所有。商業轉載請聯系作者獲得授權,非商業轉載請注明出處。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM