torch.argwhere
torch.argwhere(input) → Tensor
返回一个张量,该张量包含输入中所有非零元素的下标。
对于一个n维的输入,输出结果的下标的大小是z×n维的,z是所有非零元素的总个数。
t = torch.tensor([1, 0, 1])
torch.argwhere(t)
tensor([[0],
[2]])
t = torch.tensor([[1, 0, 1], [0, 1, 1]])
torch.argwhere(t)
tensor([[0, 0],
[0, 2],
[1, 1],
[1, 2]])