官方文档:https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/
顾名思义,返回非零元素的索引:
先从一维的看起:
import torch
x = torch.Tensor([0, 1, 2, 3, 0, 5])
y = torch.nonzero(x)
print(y)
print(y.size())
输出:
tensor([[1],
[2],
[3],
[5]])
torch.Size([4, 1])
注意这里,首先输出y的维度为4×1,4表示非零元素的数量,1表示我们的输入维度。
再看一个二维的例子:
import torch
x = torch.Tensor([
[0, 1, 2, 3, 0, 5],
[0, 0, 0, 0, 0, 10]
])
y = torch.nonzero(x)
print(y)
print(y.size())
for i in y:
print(i[0].numpy(), i[1].numpy())
输出:
tensor([[0, 1],
[0, 2],
[0, 3],
[0, 5],
[1, 5]])
torch.Size([5, 2])
0 1
0 2
0 3
0 5
1 5
输出y的维度为5×2,5表示非零元素的数量,2表示我们的输入维度。输出的每一行提供了相应非零元素的索引。