- 官方文档 可供参考
torch.where(condition, input, other, *, out=None) → Tensor
- Return a tensor of elements selected from either
input
orother
, depending oncondition
. - Parameters:
condition (BoolTensor) – When True (nonzero), yield input, otherwise yield other
input (Tensor or Scalar) – value (if input is a scalar) or values selected at indices where condition is True
other (Tensor or Scalar) – value (if other is a scalar) or values selected at indices where condition is False
- Example:
>>> a = torch.arange(0,8).view(2,4)
>>> b = -torch.arange(0,8).view(2,4)
>>> torch.where(a>4, a, b)
tensor([[ 0, -1, -2, -3],
[-4, 5, 6, 7]])
>>> a>4
tensor([[False, False, False, False],
[False, True, True, True]])
对该 example 的解释如下:
结果的 shape
与 input: a
一致,结果中具体的值设置的方法如下:
- 如果对应位置
condition
为True
: 设置为input: a
中对应位置元素值,所以有tensor([[ *, *, *, *], [*, 5, 6, 7]])
- 如果对应位置
condition
为False
: 设置为other: b
中对应位置元素值,所以有tensor([[ 0, -1, -2, -3], [-4, *, *, *]])