import torch
x = torch.randn(1, 3, 6, 6)
y = torch.zeros(x.shape).to(x.device)
y[x >= 0.5] = 1
z = (y == 1).nonzero(as_tuple = False)
print(z)
# z is the indices you need.
pytorch tensor 获取指定值的indices
最新推荐文章于 2023-09-19 13:04:13 发布