torch.bmm(), torch.mul(), torch.matmul()


torch.mul(a, b)

点乘:对应位相乘,维度必须相等
返回维度与 a, b 相同

torch.mm(a, b)

矩阵相乘
如:
a: [1, 2]
b: [2, 3]
output: [1, 3]

torch.bmm(a, b)

a, b必须是3D维度的,对两个向量维度有要求。
a: [p, m, n]
b: [p, n, q]
output: [p, m, q]
可看作把第 0 维 p 提出,[m, n] 与 [n, q] 相乘(维度层面)

import torch
x = torch.rand(2,3,6)
y = torch.rand(2,6,7)
print(torch.bmm(x,y).size())
 
output:
torch.Size([2, 3, 7])
 
###############################
y = torch.rand(2,5,7) ##维度不匹配,报错
print(torch.bmm(x,y).size())
 
output:
Expected tensor to have size 6 at dimension 1, but got size 5 for argument #2 'batch2' (while checking arguments for bmm)
torch.matmul()

torch.matmul(input, other, out=None) → Tensor

torch.matmul

a, b 均为1D(向量)

返回两个向量的点积

import torch
x = torch.rand(2)
y = torch.rand(2)
print(torch.matmul(x,y),torch.matmul(x,y).size())
output:
tensor(0.1353) torch.Size([])

a, b 都是2D(矩阵)

按照(矩阵相乘)规则返回2D

x = torch.rand(2,4)
y = torch.rand(4,3) 
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
 
output:
tensor([[0.9128, 0.8425, 0.7269],
        [1.4441, 1.5334, 1.3273]]) 
 torch.Size([2, 3])

维度也要对应才可以乘

a = torch.ones(4,3)
b = torch.ones(4,3)
y = torch.matmul(b,a)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x3 and 4x3)

a为1维,b为2维

  • 先将1D的维度扩充到2D(1D的维数前面+1)
  • 得到结果后再将此维度去掉,得到的与input的维度相同。
    即使作扩充(广播)处理,a 的维度也要和 b 维度做对应关系。
import torch
x = torch.rand(4) #1D  
y = torch.rand(4,3) #2D
print(x.size())
print(y.size())
# x: 扩充x =>(,4) * y:[4, 3] =>[ , 3] => [3]
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
 
output:
torch.Size([4])
torch.Size([4, 3])
tensor([0.9600, 0.5736, 1.0430]) 
torch.Size([3])

a为2维,b为1维

对应行相乘求和

import torch
x = torch.rand(4,3) # 2D
y = torch.rand(3) # 1D
print(torch.matmul(x,y).size()) #2D*1D

output:
torch.Size([4])

a, b 均为 3 维

a: [batch_size, 1, seq_len]
b: [batch_size, seq_len, bert_dim]
torch.matmul(a,b) => [batch_size, 1, bert_dim]
类似 torch.bmm:
可看作把第 0 维 batch_size 提出,[m, n] 与 [n, q] 相乘(维度层面)

作者:top_小酱油
链接:https://www.jianshu.com/p/e277f7fc67b3
来源:简书
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM