import torch
# Demo: find the positions of every element of a random tensor that is >= a threshold.
THRESHOLD = 0.5
x = torch.randn(1, 3, 6, 6)
# Boolean mask of qualifying elements, computed once and reused below
# (instead of re-deriving it later via `y == 1`).
mask = x >= THRESHOLD
# Binary tensor (1.0 where mask is True, else 0.0) with x's shape, dtype and
# device. Built directly from the mask rather than allocating zeros on the CPU
# and then copying them over to x.device.
y = mask.to(x.dtype)
# Indices of the qualifying elements: one row per match, one column per
# dimension of x, i.e. shape (num_matches, x.dim()).
z = mask.nonzero(as_tuple=False)
print(z)
# z holds the indices you need.
import torch
# Demo: find the positions of every element of a random tensor that is >= a threshold.
THRESHOLD = 0.5
x = torch.randn(1, 3, 6, 6)
# Boolean mask of qualifying elements, computed once and reused below
# (instead of re-deriving it later via `y == 1`).
mask = x >= THRESHOLD
# Binary tensor (1.0 where mask is True, else 0.0) with x's shape, dtype and
# device. Built directly from the mask rather than allocating zeros on the CPU
# and then copying them over to x.device.
y = mask.to(x.dtype)
# Indices of the qualifying elements: one row per match, one column per
# dimension of x, i.e. shape (num_matches, x.dim()).
z = mask.nonzero(as_tuple=False)
print(z)
# z holds the indices you need.