首先我們看一下Pytorch中torch.where函數是怎樣定義的:
1 @overload 2 def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
torch.where函數的功能如下:
1 torch.where(condition, x, y): 2 condition:判斷條件 3 x:若滿足條件,則取x中元素 4 y:若不滿足條件,則取y中元素
以具體實例看一下torch.where函數的效果:
1 import torch 2 3 # 條件 4 condition = torch.rand(3, 2) 5 print(condition) 6 # 滿足條件則取x中對應元素 7 x = torch.ones(3, 2) 8 print(x) 9 # 不滿足條件則取y中對應元素 10 y = torch.zeros(3, 2) 11 print(y) 12 # 條件判斷后的結果 13 result = torch.where(condition > 0.5, x, y) 14 print(result)
結果如下:
1 tensor([[0.3224, 0.5789], 2 [0.8341, 0.1673], 3 [0.1668, 0.4933]]) 4 tensor([[1., 1.], 5 [1., 1.], 6 [1., 1.]]) 7 tensor([[0., 0.], 8 [0., 0.], 9 [0., 0.]]) 10 tensor([[0., 1.], 11 [1., 0.], 12 [0., 0.]])
可以看到torch.where函數會對condition中的元素逐一進行判斷,根據判斷的結果選取x或y中的值,所以要求x和y應該與condition形狀相同。