Math operations
- Add/subtract/multiply/divide
- Matmul
- Pow
- Sqrt/rsqrt
- Round
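Add, subtract, multiply, divide
The operators +, -, *, / act element-wise and match torch.add, torch.sub, torch.mul and torch.div. Shapes are broadcast starting from the last dimension. A b of shape (3) cannot be added to an a of shape (3, 4) because the last sizes, 4 and 3, differ. A b of shape (4) works and is added to every row of a.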
>>> import torch
>>> a=torch.rand(3,4)
>>> b=torch.rand(3)
>>> a+b
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1
>>> b=torch.rand(4)
>>> a+b
tensor([[1.0229, 1.0478, 1.5048, 0.1652],
        [0.7158, 1.5854, 0.9578, 1.0546],
        [0.8643, 1.6602, 1.1447, 0.4042]])
>>> a-b
tensor([[ 0.8146, -0.9272, -0.3113, -0.1604],
        [ 0.5075, -0.3895, -0.8584,  0.7290],
        [ 0.6560, -0.3147, -0.6715,  0.0786]])
>>> torch.all(torch.eq(a-b,torch.sub(a,b)))
tensor(True)
>>> torch.all(torch.eq(a+b,torch.add(a,b)))
tensor(True)
>>> torch.all(torch.eq(a*b,torch.mul(a,b)))
tensor(True)
>>> torch.all(torch.eq(a/b,torch.div(a,b)))
tensor(True)
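The broadcasting rule behind the error above can be spelled out in a few lines. This is a minimal sketch: `broadcast_shape` is a hypothetical helper written only for illustration, and the comparison with `torch.broadcast_shapes` assumes PyTorch 1.8 or later. Shapes are aligned from the right, and each pair of sizes must either be equal or contain a 1.

```python
import torch

def broadcast_shape(s1, s2):
    """Hypothetical helper: shape that broadcasting s1 with s2 produces, or None."""
    result = []
    # Walk both shapes from the last dimension; a missing dimension counts as size 1.
    for i in range(1, max(len(s1), len(s2)) + 1):
        d1 = s1[-i] if i <= len(s1) else 1
        d2 = s2[-i] if i <= len(s2) else 1
        if d1 != d2 and d1 != 1 and d2 != 1:
            return None  # sizes clash, as in the (3, 4) + (3,) case
        # A size of 1 stretches to match the other size.
        result.append(d1 if d2 == 1 else d2)
    return tuple(reversed(result))

print(broadcast_shape((3, 4), (3,)))                # None
print(broadcast_shape((3, 4), (4,)))                # (3, 4)
print(tuple(torch.broadcast_shapes((3, 4), (4,))))  # (3, 4), PyTorch agrees
```

The same rule explains why a column of shape (3, 1) would also broadcast against a (3, 4) tensor: its trailing size is 1.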
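matmul (matrix product)
torch.mm only accepts 2D tensors (matrices). torch.matmul and the @ operator also accept higher-dimensional tensors and treat the leading dimensions as a batch. Note that * is the element-wise product, not the matrix product.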
>>> a
tensor([[0.5370, 0.6411],
        [0.6101, 0.4873]])
>>> b
tensor([[0.6591, 0.9531],
        [0.9422, 0.0774]])
>>> a*b  # element-wise product
tensor([[0.3539, 0.6110],
        [0.5749, 0.0377]])
>>> a@b  # matrix product
tensor([[0.9580, 0.5614],
        [0.8612, 0.6192]])
>>> torch.all(torch.eq(torch.matmul(a,b),a@b))
tensor(True)
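To contrast the 2D-only torch.mm with the batched behaviour of torch.matmul, here is a small sketch. The shapes and the seed are arbitrary choices for illustration. It also checks one entry of the product against the row-times-column definition.

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 4)   # a batch of two 3x4 matrices
w = torch.rand(4, 5)      # one 4x5 matrix, broadcast across the batch

y = x @ w                 # same as torch.matmul(x, w)
print(y.shape)            # torch.Size([2, 3, 5])

# torch.mm only takes 2D inputs, so apply it to each batch element in turn.
y_loop = torch.stack([torch.mm(x[i], w) for i in range(x.shape[0])])
print(torch.allclose(y, y_loop))  # True

# Entry (i, j) of a matrix product is row i of the left matrix dotted with column j of the right.
manual = (x[0, 1, :] * w[:, 2]).sum()
print(torch.allclose(y[0, 1, 2], manual))  # True
```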
pow (exponentiation)
>>> torch.full([2,2],3.)  # float fill value, so the result is a float tensor
tensor([[3., 3.],
        [3., 3.]])
>>> torch.full([2,2],3.)**3
tensor([[27., 27.],
        [27., 27.]])
>>> torch.full([2,2],3.).pow(2)
tensor([[9., 9.],
        [9., 9.]])
>>> torch.full([2,2],3.).pow(2).sqrt()
tensor([[3., 3.],
        [3., 3.]])
>>> torch.full([2,2],3.).pow(2).rsqrt()
tensor([[0.3333, 0.3333],
        [0.3333, 0.3333]])
Note: rsqrt is the reciprocal square root, 1/sqrt(x). Here it gives 1/sqrt(9) = 1/3.
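A short sketch of how ** , pow, sqrt and rsqrt relate, using an arbitrary float tensor. Comparisons use torch.allclose because a fractional power and sqrt may differ in the last bits.

```python
import torch

x = torch.tensor([1.0, 4.0, 9.0])

# ** and .pow() are the same operation.
print(torch.equal(x ** 2, x.pow(2)))          # True
# A fractional exponent gives a root: x ** 0.5 matches sqrt (up to rounding).
print(torch.allclose(x.pow(0.5), x.sqrt()))   # True
# rsqrt is the reciprocal square root, 1 / sqrt(x): here 1, 1/2, 1/3.
print(torch.allclose(x.rsqrt(), torch.tensor([1.0, 0.5, 1 / 3])))  # True
```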
exp, log
>>> torch.exp(torch.full([2,2],1.))
tensor([[2.7183, 2.7183],
        [2.7183, 2.7183]])
>>> torch.log(torch.exp(torch.full([2,2],1.)))  # torch.log is the natural log, so it undoes exp
tensor([[1., 1.],
        [1., 1.]])
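Since torch.log is base e, the other common bases need torch.log2 or torch.log10, or the change-of-base formula. A minimal sketch with arbitrary sample values:

```python
import torch

x = torch.tensor([1.0, 8.0, 100.0])

print(torch.log2(x))    # approximately [0, 3, 6.644]
print(torch.log10(x))   # approximately [0, 0.903, 2]

# Change of base: log_b(x) = ln(x) / ln(b)
print(torch.allclose(torch.log(x) / torch.log(torch.tensor(2.0)), torch.log2(x)))  # True
# exp and log are inverses of each other
print(torch.allclose(torch.exp(torch.log(x)), x))  # True
```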
Approximation
- .floor() / .ceil(): round down / round up
- .round(): round to the nearest integer
- .trunc() / .frac(): integer part / fractional part
>>> a=torch.tensor(3.14)
>>> a.floor(),a.ceil(),a.trunc(),a.round(),a.frac()
(tensor(3.), tensor(4.), tensor(3.), tensor(3.), tensor(0.1400))
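The positive example above hides two details worth checking. For negative numbers floor and trunc differ, and round breaks exact ties by rounding to the nearest even integer. A small sketch with arbitrary values:

```python
import torch

a = torch.tensor(-3.14)
# floor rounds toward -inf, trunc rounds toward 0, so they differ for negatives.
print(a.floor(), a.ceil(), a.trunc(), a.frac())
# tensor(-4.) tensor(-3.) tensor(-3.) tensor(-0.1400)

# trunc and frac split a number into its integer and fractional parts.
print(torch.allclose(a.trunc() + a.frac(), a))  # True

# round() sends exact halves to the nearest even integer ("round half to even").
print(torch.tensor([0.5, 1.5, 2.5, 3.5]).round())
# tensor([0., 2., 2., 4.])
```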
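clamp
clamp(min, max) limits every element to the range [min, max]. With a single argument, a.clamp(10) only sets a minimum. A typical use is gradient clipping.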
>>> a=torch.rand(2,2)*14
>>> a
tensor([[ 1.3472,  5.9060],
        [12.0558,  4.2571]])
>>> a.clamp(10)  # minimum only: values below 10 become 10
tensor([[10.0000, 10.0000],
        [12.0558, 10.0000]])
>>> a.clamp(0,10)  # range [0, 10]
tensor([[ 1.3472,  5.9060],
        [10.0000,  4.2571]])
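A minimal sketch of clamp used for gradient clipping. It assumes a toy linear model and an arbitrary threshold of 0.5, both chosen only for illustration. The manual clamp is compared with torch.nn.utils.clip_grad_value_, which performs the same element-wise clipping in place. torch.nn.utils.clip_grad_norm_ is the norm-based alternative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)          # toy model, for illustration only
x = torch.randn(8, 4) * 10       # large inputs, so some gradients exceed the threshold
loss = model(x).pow(2).mean()
loss.backward()

clip = 0.5                       # arbitrary threshold
# Out-of-place: clamp a copy of every gradient into [-clip, clip].
manual = [p.grad.clamp(-clip, clip) for p in model.parameters()]
# Built-in equivalent, which clamps each .grad in place.
nn.utils.clip_grad_value_(model.parameters(), clip_value=clip)

print(all(torch.equal(m, p.grad) for m, p in zip(manual, model.parameters())))  # True
print(max(p.grad.abs().max().item() for p in model.parameters()) <= clip)        # True
```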
Statistics
▪ norm
▪ mean, sum
▪ prod
▪ max, min, argmin, argmax
▪ kthvalue, topk
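The list above only names the reductions. The sketch below uses a fixed 2x3 tensor, chosen so every result is easy to check by hand. It uses torch.linalg.vector_norm (PyTorch 1.10+) for the norm, although Tensor.norm also exists.

```python
import torch

a = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])

print(torch.linalg.vector_norm(a, ord=1))  # tensor(21.)   L1 norm: sum of |x|
print(torch.linalg.vector_norm(a))         # L2 norm: sqrt(91), about 9.539
print(a.mean(), a.sum(), a.prod())         # tensor(3.5000) tensor(21.) tensor(720.)

# Without dim, argmax/argmin index into the flattened tensor.
print(a.max(), a.argmax(), a.argmin())     # tensor(6.) tensor(5) tensor(0)

# With dim, max returns (values, indices) along that dimension.
values, indices = a.max(dim=1)
print(values, indices)                     # tensor([3., 6.]) tensor([2, 2])

# kthvalue: k-th smallest along a dim; topk: k largest, sorted in descending order.
print(a.kthvalue(2, dim=1).values)         # tensor([2., 5.])
print(a.topk(2, dim=1).values)             # values [[3., 2.], [6., 5.]]
```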