代码实例
mm = torch.tensor([[1, 3, 5, 1], [1, 4, 1, 8]])
print('mm:\n', mm)
print('mm.data:\n', mm.data)
print('mm.data.eq(1):\n', mm.data.eq(1))
print('mm.data.eq(1).nonzero():\n', mm.data.eq(1).nonzero())
运行结果
mm:
tensor([[1, 3, 5, 1],
[1, 4, 1, 8]])
mm.data:
tensor([[1, 3, 5, 1],
[1, 4, 1, 8]])
mm.data.eq(1):
tensor([[ True, False, False, True],
[ True, False, True, False]])
mm.data.eq(1).nonzero():
tensor([[0, 0],
[0, 3],
[1, 0],
[1, 2]])
*.nonzero()可以将数组中的bool型转换为相对应的坐标。