运算 函数
大于 torch.gt
小于 torch.lt
等于 torch.eq
非零 torch.nonzero
非 torch.ne
import torch
x = torch.arange(5)
print(x)
mask = torch.gt(x,1) # 大于
print(mask)
print(x[mask])
x = torch.arange(5)
print(x)
mask = torch.lt(x,3) # 小于
print(mask)
print(x[mask])
x = torch.arange(5)
print(x)
mask = torch.eq(x,3) # 等于
print(mask)
print(x[mask])
x = torch.Tensor([1,2,1,0,0])
mask = torch.ne(x,1) # 非,一个数
print(mask)
print(x[mask])
a = torch.Tensor([[0.6, 0.0, 0.0, 0.0],[0.0, 0.4, 0.0, 0.0],[0.0, 0.0, 1.2, 0.0],[0.0, 0.0, 0.0,-0.4]])
mask = torch.nonzero(a) # 非零
print(mask)
print(torch.numel(mask))
print(torch.numel(a))
# print(a[mask])
print(torch.numel(mask)/torch.numel(a))