Pytorch中的torch.where函數


首先我們看一下Pytorch中torch.where函數是怎樣定義的:

1 @overload
2 def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...

torch.where函數的功能如下:

1 torch.where(condition, x, y):
2 condition:判斷條件
3 x:若滿足條件,則取x中元素
4 y:若不滿足條件,則取y中元素

以具體實例看一下torch.where函數的效果:

 1 import torch
 2  
 3 # 條件
 4 condition = torch.rand(3, 2)
 5 print(condition)
 6 # 滿足條件則取x中對應元素
 7 x = torch.ones(3, 2)
 8 print(x)
 9 # 不滿足條件則取y中對應元素
10 y = torch.zeros(3, 2)
11 print(y)
12 # 條件判斷后的結果
13 result = torch.where(condition > 0.5, x, y)
14 print(result)

結果如下:

 1 tensor([[0.3224, 0.5789],
 2         [0.8341, 0.1673],
 3         [0.1668, 0.4933]])
 4 tensor([[1., 1.],
 5         [1., 1.],
 6         [1., 1.]])
 7 tensor([[0., 0.],
 8         [0., 0.],
 9         [0., 0.]])
10 tensor([[0., 1.],
11         [1., 0.],
12         [0., 0.]])

可以看到torch.where函數會對condition中的元素逐一進行判斷,根據判斷的結果選取x或y中的值,所以要求x和y應該與condition形狀相同。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM