torch中的幾個函數(min()、max()、prod()、clamp_()、where())


torch.min()、torch.max()、torch.prod() 

   這兩個函數很好理解,就是求張量中的最小值和最大值以及相乘

    1.在這兩個函數中如果沒有指定維度的話,那么默認是將張量中的所有值進行比較,輸出最大值或者最小值或是所有值相乘。

    2.而當指定維度之后,會將對應維度的數據進行比較,同時輸出的有最小值以及這個最小值在對應維度的下標,或是指定維度相乘

    3.使用這兩個函數對兩個張量進行比較時,輸出的是張量中每一個值對應位置的最小值,對應位置相乘

 1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]])
 2 >>> b = torch.Tensor([[2,1,4,3],[6,5,8,7]])
 3 >>> torch.min(a)
 4 tensor(1.)
 5 
 6 >>> torch.min(a,dim=0)
 7 torch.return_types.min(
 8 values=tensor([1., 2., 3., 4.]),
 9 indices=tensor([0, 0, 0, 0]))
10 
11 >>> torch.min(a,b)
12 tensor([[1., 1., 3., 3.],
13         [5., 5., 7., 7.]])
14 
15 >>> torch.min(a[:,2:],b[:,2:])
16 tensor([[3., 3.],
17         [7., 7.]])

torch.clamp_()

  這個函數其實就是對張量進行上下限的限制,超過了指定的上限或是下限之后,該值賦值為確定的界限的值

 

1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]])
2 >>> a.clamp_(min = 2.5,max = 6.5)
3 
4 tensor([[2.5000, 2.5000, 3.0000, 4.0000],
5         [5.0000, 6.0000, 6.5000, 6.5000]])

 

torch.where()

  函數的定義如下:

  1. torch.where(condition, x, y):
  2.  condition:判斷條件
  3.  x:若滿足條件,則取x中元素
  4.  y:若不滿足條件,則取y中元素
 1 import torch
 2 # 條件
 3 condition = torch.rand(3, 2)
 4 print(condition)
 5 # 滿足條件則取x中對應元素
 6 x = torch.ones(3, 2)
 7 print(x)
 8 # 不滿足條件則取y中對應元素
 9 y = torch.zeros(3, 2)
10 print(y)
11 # 條件判斷后的結果
12 result = torch.where(condition > 0.5, x, y)
13 print(result)
14 
15 
16 
17 
18 tensor([[0.3224, 0.5789],
19         [0.8341, 0.1673],
20         [0.1668, 0.4933]])
21 tensor([[1., 1.],
22         [1., 1.],
23         [1., 1.]])
24 tensor([[0., 0.],
25         [0., 0.],
26         [0., 0.]])
27 tensor([[0., 1.],
28         [1., 0.],
29         [0., 0.]])

  可以看到是對張量中的每一個值進行比較,單獨進行條件判斷,輸出張量對應的位置為判斷后對應判擇的輸出

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM