目錄
torch.mul(a, b)
點乘:對應位相乘,維度必須相等
返回維度與 a, b 相同
torch.mm(a, b)
矩陣相乘
如:
a: [1, 2]
b: [2, 3]
output: [1, 3]
torch.bmm(a, b)
a, b必須是3D維度的,對兩個向量維度有要求。
a: [p, m, n]
b: [p, n, q]
output: [p, m, q]
可看作把第 0 維 p 提出,[m, n] 與 [n, q] 相乘(維度層面)
import torch
x = torch.rand(2,3,6)
y = torch.rand(2,6,7)
print(torch.bmm(x,y).size())
output:
torch.Size([2, 3, 7])
###############################
y = torch.rand(2,5,7) ##維度不匹配,報錯
print(torch.bmm(x,y).size())
output:
Expected tensor to have size 6 at dimension 1, but got size 5 for argument #2 'batch2' (while checking arguments for bmm)
torch.matmul()
torch.matmul(input, other, out=None) → Tensor
torch.matmul
a, b 均為1D(向量)
返回兩個向量的點積
import torch
x = torch.rand(2)
y = torch.rand(2)
print(torch.matmul(x,y),torch.matmul(x,y).size())
output:
tensor(0.1353) torch.Size([])
a, b 都是2D(矩陣)
按照(矩陣相乘)規則返回2D
x = torch.rand(2,4)
y = torch.rand(4,3)
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
output:
tensor([[0.9128, 0.8425, 0.7269],
[1.4441, 1.5334, 1.3273]])
torch.Size([2, 3])
維度也要對應才可以乘
a = torch.ones(4,3)
b = torch.ones(4,3)
y = torch.matmul(b,a)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x3 and 4x3)
a為1維,b為2維
- 先將1D的維度擴充到2D(1D的維數前面+1)
- 得到結果后再將此維度去掉,得到的與input的維度相同。
即使作擴充(廣播)處理,a 的維度也要和 b 維度做對應關系。
import torch
x = torch.rand(4) #1D
y = torch.rand(4,3) #2D
print(x.size())
print(y.size())
# x: 擴充x =>(,4) * y:[4, 3] =>[ , 3] => [3]
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
output:
torch.Size([4])
torch.Size([4, 3])
tensor([0.9600, 0.5736, 1.0430])
torch.Size([3])
a為2維,b為1維
對應行相乘求和
import torch
x = torch.rand(4,3) # 2D
y = torch.rand(3) # 1D
print(torch.matmul(x,y).size()) #2D*1D
output:
torch.Size([4])
a, b 均為 3 維
a: [batch_size, 1, seq_len]
b: [batch_size, seq_len, bert_dim]
torch.matmul(a,b) => [batch_size, 1, bert_dim]
類似 torch.bmm:
可看作把第 0 維 batch_size 提出,[m, n] 與 [n, q] 相乘(維度層面)
作者:top_小醬油
鏈接:https://www.jianshu.com/p/e277f7fc67b3
來源:簡書
著作權歸作者所有。商業轉載請聯系作者獲得授權,非商業轉載請注明出處。