input
import torch
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
# 取出对于元素的布尔值
print(a)
print(a > 3) #得到的是一系列布尔值
print(a[a > 3]) #取出数组中索引为True的的值(a中大于3的值)
# 取出每个轴的索引,默认是非0元素的索引(取出a中的大于3的元素对应的索引)
print(torch.nonzero(a > 3, as_tuple=False))
output
tensor([[1, 2],
[3, 4],
[5, 6]])
tensor([[False, False],
[False, True],
[ True, True]])
tensor([4, 5, 6])
tensor([[1, 1],
[2, 0],
[2, 1]])
input
import torch
b = torch.tensor([[0, 1, 0, 5, 6, 8], [1, 2, 0, 0, 5, 0], [1, 1, 5, 0, 0, 5]], dtype=torch.float32)
print(b)
"取出非0元素的索引"
print(b.nonzero())
print(len(b.reshape(-1)))
print(len(b.nonzero()))
output
tensor([[0., 1., 0., 5., 6., 8.],
[1., 2., 0., 0., 5., 0.],
[1., 1., 5., 0., 0., 5.]])
tensor([[0, 1],
[0, 3],
[0, 4],
[0, 5],
[1, 0],
[1, 1],
[1, 4],
[2, 0],
[2, 1],
[2, 2],
[2, 5]])
18
11