1. torch.where(condition, x, y) → Tensor
Returns a tensor whose elements are taken from x or y, element by element, depending on condition:
\text{out}_i = \begin{cases} \mathrm{x}_i & \text{if } \text{condition}_i \\ \mathrm{y}_i & \text{otherwise} \end{cases}
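To make the selection rule concrete, here is a plain-Python model of the formula (not PyTorch's implementation). It assumes condition, x and y are nested lists of the same shape, so broadcasting is not modeled; the helper name `where_ref` is hypothetical. The data are the values from case 1.

```python
def where_ref(cond, x, y):
    """Plain-Python model of out_i = x_i if condition_i else y_i.

    Assumes cond, x and y are nested lists of identical shape (no broadcasting).
    """
    if isinstance(cond, list):
        # Recurse into each sub-list, pairing up elements at the same position.
        return [where_ref(c, xi, yi) for c, xi, yi in zip(cond, x, y)]
    # Leaf: take the value from x where the condition is truthy, otherwise from y.
    return x if cond else y


x = [[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]
y = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
cond = [[v > 0 for v in row] for row in x]
print(where_ref(cond, x, y))
# [[1.0, 0.3139], [0.3898, 1.0], [0.0478, 1.0]]
```

The result matches the torch.where(x > 0, x, y) output in case 1: positive entries come from x, the rest from y.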
case 1:
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620, 0.3139],
[ 0.3898, -0.7197],
[ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000, 0.3139],
[ 0.3898, 1.0000],
[ 0.0478, 1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779, 0.0383],
[-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
[0.0000, 0.0000]], dtype=torch.float64)
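The second call passes the scalar `0.` as y, which PyTorch broadcasts to the shape of condition. A minimal plain-Python sketch of that behavior, assuming only scalar-vs-list broadcasting (general broadcasting between lists of different shapes is not modeled; `where_ref` is a hypothetical name):

```python
def where_ref(cond, x, y):
    """Plain-Python model of torch.where(cond, x, y) where x or y may be a number.

    A non-list x or y is reused at every position, mimicking how a scalar is
    broadcast to the condition's shape. Broadcasting between lists of
    different shapes is NOT modeled.
    """
    if isinstance(cond, list):
        # Repeat a scalar once per element at this level; pass lists through.
        xs = x if isinstance(x, list) else [x] * len(cond)
        ys = y if isinstance(y, list) else [y] * len(cond)
        return [where_ref(c, xi, yi) for c, xi, yi in zip(cond, xs, ys)]
    # Leaf: select from x where the condition holds, otherwise from y.
    return x if cond else y


x = [[1.0779, 0.0383], [-0.8785, -1.1089]]
cond = [[v > 0 for v in row] for row in x]
print(where_ref(cond, x, 0.0))
# [[1.0779, 0.0383], [0.0, 0.0]]
```

This reproduces the float64 example above: positive entries are kept and every other position is filled with the scalar.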
2. torch.where(condition) → tuple of LongTensor
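In this form torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True): it finds the indices of the elements of condition that are non-zero (True) and returns a tuple with one 1-D index tensor per dimension of condition, each of length n, where n is the number of non-zero elements. The n*d layout belongs to the default torch.nonzero(condition) (as_tuple=False), which returns a single tensor with one row of coordinates per non-zero element: an n*2 tensor when condition is 2-D and an n*3 tensor when it is 3-D, as in the b.nonzero() examples below.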
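The two output layouts can be modeled in plain Python (this is an illustration, not PyTorch's implementation). The sketch assumes condition is a nested list of truth values; `nonzero_rows`, `count_dims` and `where_indices` are hypothetical names. `nonzero_rows` mirrors condition.nonzero() (n rows of d indices), and `where_indices` mirrors torch.where(condition) (d lists of n indices), obtained by transposing the rows. The mask is the 4×3 b from case 2.

```python
def nonzero_rows(cond, prefix=()):
    """Model of condition.nonzero(): one index row per truthy element.

    cond is a nested list; rows come out in row-major order.
    """
    if isinstance(cond, list):
        rows = []
        for i, sub in enumerate(cond):
            # Extend the current index prefix with this position and recurse.
            rows.extend(nonzero_rows(sub, prefix + (i,)))
        return rows
    return [list(prefix)] if cond else []


def count_dims(cond):
    """Number of nesting levels, i.e. the number of dimensions of cond."""
    d = 0
    while isinstance(cond, list):
        d += 1
        cond = cond[0] if cond else None
    return d


def where_indices(cond):
    """Model of torch.where(cond): a tuple of d index lists, each of length n."""
    rows = nonzero_rows(cond)
    # Transposing the n x d rows gives one list per dimension.
    cols = [list(col) for col in zip(*rows)]
    # zip(*[]) is empty, so return d empty lists when nothing is True.
    return tuple(cols) if cols else tuple([] for _ in range(count_dims(cond)))


b = [[False, True, True],
     [False, False, True],
     [True, True, False],
     [False, True, False]]
print(nonzero_rows(b))
# [[0, 1], [0, 2], [1, 2], [2, 0], [2, 1], [3, 1]]
print(where_indices(b))
# ([0, 0, 1, 2, 2, 3], [1, 2, 2, 0, 1, 1])
```

The first printout matches b.nonzero() in case 2; the second is the same information split into one index list per dimension, which is the shape torch.where(b) returns.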
case 2:
>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
(tensor([0, 1, 2, 4]),)
>>> a = torch.randn(4,3)
>>> a
tensor([[-0.1272, 1.0898, 0.1808],
[-0.6453, -1.5093, 2.1270],
[ 0.3852, 1.0197, -0.2998],
[-0.2550, 1.2353, -0.1208]])
>>> b = (a>0.1)
>>> b
tensor([[False, True, True],
[False, False, True],
[ True, True, False],
[False, True, False]])
>>> b.nonzero()
tensor([[0, 1],
[0, 2],
[1, 2],
[2, 0],
[2, 1],
[3, 1]])
>>> a = torch.randn(2,3,4)
>>> a
tensor([[[ 0.0642, -2.1055, -1.0024, 0.6682],
[-0.1130, -0.5724, 0.4137, -0.4402],
[ 1.3719, 0.7411, 1.7062, -0.9893]],
[[-0.1432, 0.4898, -0.4913, 1.1397],
[-0.0883, 1.1581, -0.4784, 0.0195],
[ 1.3767, 2.2682, -0.6881, 1.9919]]])
>>> b = (a>0.1)
>>> b
tensor([[[False, False, False, True],
[False, False, True, False],
[ True, True, True, False]],
[[False, True, False, True],
[False, True, False, False],
[ True, True, False, True]]])
>>> b.nonzero()
tensor([[0, 0, 3],
[0, 1, 2],
[0, 2, 0],
[0, 2, 1],
[0, 2, 2],
[1, 0, 1],
[1, 0, 3],
[1, 1, 1],
[1, 2, 0],
[1, 2, 1],
[1, 2, 3]])
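For the 3-D mask, torch.where(b) would hold the same 11 index triples column-wise. The sketch below uses plain lists in place of tensors (an assumption for illustration) and simply transposes the rows printed by b.nonzero(); in PyTorch itself the corresponding operation should be b.nonzero().unbind(1).

```python
# Index rows printed by b.nonzero() for the 3-D mask above (n = 11, d = 3).
rows = [[0, 0, 3], [0, 1, 2], [0, 2, 0], [0, 2, 1], [0, 2, 2],
        [1, 0, 1], [1, 0, 3], [1, 1, 1], [1, 2, 0], [1, 2, 1], [1, 2, 3]]

# The tuple form holds the same indices column-wise: one list per dimension.
as_tuple = tuple(list(col) for col in zip(*rows))
print(as_tuple)
# ([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], [0, 1, 2, 2, 2, 0, 0, 1, 2, 2, 2], [3, 2, 0, 1, 2, 1, 3, 1, 0, 1, 3])

# Zipping the columns back together recovers the original rows.
assert [list(r) for r in zip(*as_tuple)] == rows
```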