input
import torch
a = torch.tensor([1,2,3,4,5])
# 取布尔值
print(a<4)
print(torch.lt(a,4)) #lt gt eq le ge
# 取值(下面3中写法可以达到同样的效果)
print(a[a < 4])
print(torch.masked_select(a, a < 4))
print(torch.masked_select(a, torch.lt(a, 4)))
# 取索引,默认是非0元素的索引
print(torch.nonzero(a < 4, as_tuple=False))
output
tensor([ True, True, True, False, False])
tensor([ True, True, True, False, False])
tensor([1, 2, 3])
tensor([1, 2, 3])
tensor([1, 2, 3])
tensor([[0],
[1],
[2]])