torch.min()、torch.max()、torch.prod()
这两个函数很好理解,就是求张量中的最小值和最大值以及相乘
1.在这两个函数中如果没有指定维度的话,那么默认是将张量中的所有值进行比较,输出最大值或者最小值或是所有值相乘。
2.而当指定维度之后,会将对应维度的数据进行比较,同时输出的有最小值以及这个最小值在对应维度的下标,或是指定维度相乘
3.使用这两个函数对两个张量进行比较时,输出的是张量中每一个值对应位置的最小值,对应位置相乘
1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]]) 2 >>> b = torch.Tensor([[2,1,4,3],[6,5,8,7]]) 3 >>> torch.min(a) 4 tensor(1.) 5 6 >>> torch.min(a,dim=0) 7 torch.return_types.min( 8 values=tensor([1., 2., 3., 4.]), 9 indices=tensor([0, 0, 0, 0])) 10 11 >>> torch.min(a,b) 12 tensor([[1., 1., 3., 3.], 13 [5., 5., 7., 7.]]) 14 15 >>> torch.min(a[:,2:],b[:,2:]) 16 tensor([[3., 3.], 17 [7., 7.]])
torch.clamp_()
这个函数其实就是对张量进行上下限的限制,超过了指定的上限或是下限之后,该值赋值为确定的界限的值
1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]]) 2 >>> a.clamp_(min = 2.5,max = 6.5) 3 4 tensor([[2.5000, 2.5000, 3.0000, 4.0000], 5 [5.0000, 6.0000, 6.5000, 6.5000]])
torch.where()
函数的定义如下:
-
torch.where(condition, x, y):
-
condition:判断条件
-
x:若满足条件,则取x中元素
-
y:若不满足条件,则取y中元素
1 import torch 2 # 条件 3 condition = torch.rand(3, 2) 4 print(condition) 5 # 满足条件则取x中对应元素 6 x = torch.ones(3, 2) 7 print(x) 8 # 不满足条件则取y中对应元素 9 y = torch.zeros(3, 2) 10 print(y) 11 # 条件判断后的结果 12 result = torch.where(condition > 0.5, x, y) 13 print(result) 14 15 16 17 18 tensor([[0.3224, 0.5789], 19 [0.8341, 0.1673], 20 [0.1668, 0.4933]]) 21 tensor([[1., 1.], 22 [1., 1.], 23 [1., 1.]]) 24 tensor([[0., 0.], 25 [0., 0.], 26 [0., 0.]]) 27 tensor([[0., 1.], 28 [1., 0.], 29 [0., 0.]])
可以看到是对张量中的每一个值进行比较,单独进行条件判断,输出张量对应的位置为判断后对应判择的输出