torch.nonzero
1-D input
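import torch

# Random 6x2 integer tensor; randint's upper bound is exclusive, so every value is 15 or 16.
input = torch.randint(15, 17, (6, 2))
# Indices of the rows whose first column is not 15 (the difference is non-zero there).
rels_mask = torch.nonzero(input[:, 0] - 15)
print(input)
print(input[:, 0].shape)  # the column slice is 1-D
print('----------------------')
print(rels_mask)
print(rels_mask.shape)  # (N, 1): one row per non-zero element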
Output (one sample run; the values are random)
tensor([[15, 15],
        [16, 15],
        [16, 16],
        [16, 16],
        [15, 15],
        [16, 16]])
torch.Size([6])
----------------------
tensor([[1],
        [2],
        [3],
        [5]])
torch.Size([4, 1])
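Because the result has shape (N, 1) even for a 1-D input, it usually needs flattening before it can be used as a plain index. A minimal sketch, assuming the first column from the sample run above is hard-coded so the result is reproducible (the names `col`, `idx_2d`, `idx_1d` and `idx_tuple` are only illustrative):

```python
import torch

# First column of the sample output above, fixed so the run is deterministic.
col = torch.tensor([15, 16, 16, 16, 15, 16])

# Default form: one row per non-zero element, one column per input dimension.
idx_2d = torch.nonzero(col - 15)           # shape (4, 1)

# Drop the trailing dimension to get a flat index vector.
idx_1d = idx_2d.squeeze(1)                 # shape (4,)

# as_tuple=True returns a tuple with one 1-D index tensor per input dimension.
(idx_tuple,) = torch.nonzero(col - 15, as_tuple=True)

print(idx_1d)        # tensor([1, 2, 3, 5])
print(col[idx_1d])   # tensor([16, 16, 16, 16])
```

Both `idx_1d` and `idx_tuple` hold the same indices; which form is more convenient depends on whether the indices are kept as data or fed straight back into indexing.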
2-D and 3-D input
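For higher-dimensional input the pattern is the same: each row of the result is the full coordinate of one non-zero element, so the result has shape (N, input.dim()). A small sketch with hand-made tensors (the values and the names `m`, `t`, `idx2`, `idx3` are made up for illustration):

```python
import torch

# 2-D: each result row is a (row, col) pair.
m = torch.tensor([[0, 7, 0],
                  [3, 0, 5]])
idx2 = torch.nonzero(m)
print(idx2)         # rows: [0, 1], [1, 0], [1, 2]
print(idx2.shape)   # torch.Size([3, 2])

# 3-D: each result row is a (dim0, dim1, dim2) triple.
t = torch.zeros(2, 2, 3, dtype=torch.long)
t[0, 1, 2] = 4
t[1, 0, 0] = 9
idx3 = torch.nonzero(t)
print(idx3)         # rows: [0, 1, 2], [1, 0, 0]
print(idx3.shape)   # torch.Size([2, 3])
```

The rows come back in row-major order, i.e. sorted by the first index, then the second, and so on.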
Reference
https://blog.csdn.net/monchin/article/details/79750216