函数原型:
torch.nonzero(input, out=None) → LongTensor
参数:
input (Tensor) – 源张量
out (LongTensor, optional) – 包含索引值的结果张量
代码示例
返回一个包含输入input中非零元素索引的张量。输出张量中的每行包含输入中非零元素的索引。
x = torch.tensor([0, 0, 1, 5, 8])
y = torch.nonzero(x)
print(y)
print(y.shape)
>>>
tensor([[2],
[3],
[4]])
torch.Size([3, 1])
如果输入input有n维,则输出的索引张量output的形状为 z x n, 这里 z 是输入张量input中所有非零元素的个数。
x = torch.tensor([[0, 0, 1, 5],
[1, 5, 0, 8],
[2, 8, 9, 0]])
y = torch.nonzero(x)
print(y)
print(y.shape)
>>>
tensor([[0, 2],
[0, 3],
[1, 0],
[1, 1],
[1, 3],
[2, 0],
[2, 1],
[2, 2]])
torch.Size([8, 2])
pytorch文档学习链接:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/