目录
torch.mul(a, b)
点乘:对应位相乘,维度必须相等
返回维度与 a, b 相同
torch.mm(a, b)
矩阵相乘
如:
a: [1, 2]
b: [2, 3]
output: [1, 3]
torch.bmm(a, b)
a, b必须是3D维度的,对两个向量维度有要求。
a: [p, m, n]
b: [p, n, q]
output: [p, m, q]
可看作把第 0 维 p 提出,[m, n] 与 [n, q] 相乘(维度层面)
import torch
x = torch.rand(2,3,6)
y = torch.rand(2,6,7)
print(torch.bmm(x,y).size())
output:
torch.Size([2, 3, 7])
###############################
y = torch.rand(2,5,7) ##维度不匹配,报错
print(torch.bmm(x,y).size())
output:
Expected tensor to have size 6 at dimension 1, but got size 5 for argument #2 'batch2' (while checking arguments for bmm)
torch.matmul()
torch.matmul(input, other, out=None) → Tensor
torch.matmul
a, b 均为1D(向量)
返回两个向量的点积
import torch
x = torch.rand(2)
y = torch.rand(2)
print(torch.matmul(x,y),torch.matmul(x,y).size())
output:
tensor(0.1353) torch.Size([])
a, b 都是2D(矩阵)
按照(矩阵相乘)规则返回2D
x = torch.rand(2,4)
y = torch.rand(4,3)
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
output:
tensor([[0.9128, 0.8425, 0.7269],
[1.4441, 1.5334, 1.3273]])
torch.Size([2, 3])
维度也要对应才可以乘
a = torch.ones(4,3)
b = torch.ones(4,3)
y = torch.matmul(b,a)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x3 and 4x3)
a为1维,b为2维
- 先将1D的维度扩充到2D(1D的维数前面+1)
- 得到结果后再将此维度去掉,得到的与input的维度相同。
即使作扩充(广播)处理,a 的维度也要和 b 维度做对应关系。
import torch
x = torch.rand(4) #1D
y = torch.rand(4,3) #2D
print(x.size())
print(y.size())
# x: 扩充x =>(,4) * y:[4, 3] =>[ , 3] => [3]
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
output:
torch.Size([4])
torch.Size([4, 3])
tensor([0.9600, 0.5736, 1.0430])
torch.Size([3])
a为2维,b为1维
对应行相乘求和
import torch
x = torch.rand(4,3) # 2D
y = torch.rand(3) # 1D
print(torch.matmul(x,y).size()) #2D*1D
output:
torch.Size([4])
a, b 均为 3 维
a: [batch_size, 1, seq_len]
b: [batch_size, seq_len, bert_dim]
torch.matmul(a,b) => [batch_size, 1, bert_dim]
类似 torch.bmm:
可看作把第 0 维 batch_size 提出,[m, n] 与 [n, q] 相乘(维度层面)
作者:top_小酱油
链接:https://www.jianshu.com/p/e277f7fc67b3
来源:简书
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。