- torch.mul作element-wise的矩陣點乘,維數不限,可以矩陣乘標量
- 點乘都是broadcast的,可以用
torch.mul(a, b)
實現,也可以直接用*
實現。 - 當a, b維度不一致時,會自動填充到相同維度相點乘。
1 import torch 2 3 a = torch.ones(3,4) 4 print(a) 5 b = torch.Tensor([1,2,3]).reshape((3,1)) 6 print(b) 7 8 print(torch.mul(a, b))
1 tensor([[1., 1., 1., 1.], 2 [1., 1., 1., 1.], 3 [1., 1., 1., 1.]]) 4 tensor([[1.], 5 [2.], 6 [3.]]) 7 tensor([[1., 1., 1., 1.], 8 [2., 2., 2., 2.], 9 [3., 3., 3., 3.]])