torch.where(condition, x, y) → Tensor
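Returns a tensor whose elements are selected from either x or y, depending on condition: where condition is True the element is taken from x, otherwise from y. condition must be a boolean tensor, and condition, x and y must be broadcastable to a common shape.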
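The element-wise rule can be spelled out as an explicit loop. This is a minimal sketch that assumes all three tensors already share the same shape (no broadcasting); `where_loop` is a hypothetical helper used only for illustration and is checked against `torch.where` on random data.

```python
import torch

def where_loop(condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference version of torch.where for same-shape inputs only."""
    flat_c = condition.reshape(-1)
    flat_x = x.reshape(-1)
    flat_y = y.reshape(-1)
    out = torch.empty_like(flat_x)
    # Visit every position: take x where the mask is True, y otherwise.
    for i in range(flat_c.numel()):
        out[i] = flat_x[i] if flat_c[i] else flat_y[i]
    return out.reshape(x.shape)

torch.manual_seed(0)
cond = torch.rand(3, 4) > 0.5   # random boolean mask
x = torch.randn(3, 4)
y = torch.randn(3, 4)
print(torch.equal(where_loop(cond, x, y), torch.where(cond, x, y)))  # True
```

The loop is far slower than the built-in; it exists only to make the selection rule explicit.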
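import torch

# Float scores; the actual boolean condition is built later as condition > 0.5
condition = torch.tensor([[0.6, 0.7], [0.8, 0.4]])
print(condition)
x = torch.zeros(2, 2)  # values taken where the condition is True
print(x)
y = torch.ones(2, 2)   # values taken where the condition is False
print(y)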
# Output
tensor([[0.6000, 0.7000],
        [0.8000, 0.4000]])
tensor([[0., 0.],
        [0., 0.]])
tensor([[1., 1.],
        [1., 1.]])
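# condition > 0.5 is [[True, True], [True, False]], so only the bottom-right element comes from y
print(torch.where(condition > 0.5, x, y))
# Output
tensor([[0., 0.],
        [0., 1.]])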
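Because condition, x and y only need to be broadcastable, x and y do not have to match the mask's shape. A short sketch under that assumption: here x is a (2, 1) column and y is a 0-d tensor, so both are broadcast to the (2, 2) shape of the mask. The values 10, 20 and -1 are arbitrary illustration data.

```python
import torch

condition = torch.tensor([[0.6, 0.7], [0.8, 0.4]])
mask = condition > 0.5               # bool tensor, shape (2, 2)
x = torch.tensor([[10.0], [20.0]])   # shape (2, 1), broadcast across columns
y = torch.tensor(-1.0)               # 0-d tensor, broadcast to every position
out = torch.where(mask, x, y)
print(out.tolist())  # [[10.0, 10.0], [20.0, -1.0]]
```

Each row reuses its single x value, and the one False position in the mask picks up the broadcast y.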