torch中的几个函数(min()、max()、prod()、clamp_()、where())


torch.min()、torch.max()、torch.prod() 

   这两个函数很好理解,就是求张量中的最小值和最大值以及相乘

    1.在这两个函数中如果没有指定维度的话,那么默认是将张量中的所有值进行比较,输出最大值或者最小值或是所有值相乘。

    2.而当指定维度之后,会将对应维度的数据进行比较,同时输出的有最小值以及这个最小值在对应维度的下标,或是指定维度相乘

    3.使用这两个函数对两个张量进行比较时,输出的是张量中每一个值对应位置的最小值,对应位置相乘

 1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]])
 2 >>> b = torch.Tensor([[2,1,4,3],[6,5,8,7]])
 3 >>> torch.min(a)
 4 tensor(1.)
 5 
 6 >>> torch.min(a,dim=0)
 7 torch.return_types.min(
 8 values=tensor([1., 2., 3., 4.]),
 9 indices=tensor([0, 0, 0, 0]))
10 
11 >>> torch.min(a,b)
12 tensor([[1., 1., 3., 3.],
13         [5., 5., 7., 7.]])
14 
15 >>> torch.min(a[:,2:],b[:,2:])
16 tensor([[3., 3.],
17         [7., 7.]])

torch.clamp_()

  这个函数其实就是对张量进行上下限的限制,超过了指定的上限或是下限之后,该值赋值为确定的界限的值

 

1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]])
2 >>> a.clamp_(min = 2.5,max = 6.5)
3 
4 tensor([[2.5000, 2.5000, 3.0000, 4.0000],
5         [5.0000, 6.0000, 6.5000, 6.5000]])

 

torch.where()

  函数的定义如下:

  1. torch.where(condition, x, y):
  2.  condition:判断条件
  3.  x:若满足条件,则取x中元素
  4.  y:若不满足条件,则取y中元素
 1 import torch
 2 # 条件
 3 condition = torch.rand(3, 2)
 4 print(condition)
 5 # 满足条件则取x中对应元素
 6 x = torch.ones(3, 2)
 7 print(x)
 8 # 不满足条件则取y中对应元素
 9 y = torch.zeros(3, 2)
10 print(y)
11 # 条件判断后的结果
12 result = torch.where(condition > 0.5, x, y)
13 print(result)
14 
15 
16 
17 
18 tensor([[0.3224, 0.5789],
19         [0.8341, 0.1673],
20         [0.1668, 0.4933]])
21 tensor([[1., 1.],
22         [1., 1.],
23         [1., 1.]])
24 tensor([[0., 0.],
25         [0., 0.],
26         [0., 0.]])
27 tensor([[0., 1.],
28         [1., 0.],
29         [0., 0.]])

  可以看到是对张量中的每一个值进行比较,单独进行条件判断,输出张量对应的位置为判断后对应判择的输出

 


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM