numpy.where 和 torch.where 如果不指定条件,返回的都是非0元素的位置(下标),
np.where
>>> a = np.array([-1, -2, 0, 1, 2])
>>> np.where(a)
(array([0, 1, 3, 4]),)
torch.where不指定条件,返回的也是非0元素的位置(下标),包括了负数。
看个例子
>>> c
tensor([[[ 0., 0.],
[-1., -1.],
[-1., -2.]],
[[ 1., 1.],
[ 0., 0.],
[ 0., -1.]],
[[ 1., 2.],
[ 0., 1.],
[ 0., 0.]]])
>>> torch.where(c)
(tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2]),
tensor([1, 1, 2, 2, 0, 0, 2, 0, 0, 1]),
tensor([0, 1, 0, 1, 0, 1, 1, 0, 1, 1]))
竖着按列看,可看到非零元素的下标为(0, 1, 0), (0, 1, 1)等,包含了负数。
*注:(0, 1, 0)指的是第一块,第2行的第一个元素-1。
现在指定条件 > 0
>>> torch.where(c > 0)
(tensor([1, 1, 2, 2, 2]),
tensor([0, 0, 0, 0, 1]),
tensor([0, 1, 0, 1, 1]))
可以看到负数的下标不见了。
指定条件下它们也一样,见链接。