torch.nonzero()记录tensor中零元素的位置信息

文章介绍了PyTorch中的torch.nonzero()函数,用于获取输入张量中非零元素的索引。通过实例演示了如何使用该函数处理二维和三维张量,以及输出结果的解读。
摘要由CSDN通过智能技术生成

函数作用:

返回一个包含输入input中非零元素索引的张量。输出张量中的每行包含input中非零元素的索引(也就是一行记录一个非零元素的位置信息)。如果输入input 有n维,则输出的索引张量out的size为 z*n,z 表示输入张量 input 中所有非零元素的个数。

print(torch.nonzero(torch.Tensor([[0.7, 0.0, 0.0, 0.0],
                             [0.0, 0.3, 0.0, 0.0],
                             [0.0, 0.0, 1.8, 0.0],
                             [0.0, 0.0, 0.0,-0.4]])))
输出:
tensor([[0, 0],
        [1, 1],
        [2, 2],
        [3, 3]])

输出out的理解:这个例子input是2维的,一共有4个非0元素,所以输出是一个4×2的张量,表示每个非0元素的索引(位置信息)。如out的第0行[0,0]表示input第一个非0元素的位置信息,即表示input的第0行的第0个元素是非0元素,同理,out的第1行[1,1],表示input第1行第1个元素是非0元素,等等。
 再把input设为一个3维张量,

a = torch.Tensor([
                   [
                     [1,1,1,0,1],
                     [1,0,0,0,1]
                   ],
                   [
                     [1,1,1,0,1],
                     [1,0,0,0,1]
                   ]
                ])
print(a)

b = torch.nonzero(a)
print(b)
print(b[:,0])
print(b[:,1])
print(b[:,2])
输出:
tensor([[[1., 1., 1., 0., 1.],
         [1., 0., 0., 0., 1.]],

        [[1., 1., 1., 0., 1.],
         [1., 0., 0., 0., 1.]]])
tensor([[0, 0, 0],
        [0, 0, 1],
        [0, 0, 2],
        [0, 0, 4],
        [0, 1, 0],
        [0, 1, 4],
        [1, 0, 0],
        [1, 0, 1],
        [1, 0, 2],
        [1, 0, 4],
        [1, 1, 0],
        [1, 1, 4]])
tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1])
tensor([0, 1, 2, 4, 0, 4, 0, 1, 2, 4, 0, 4])

这个out张量的意思是:
  out大小为12*3,说明input有12个非0元素,3表示input的维度是3。第0行第0列第0个元素非0,第0行第0列第1个元素非0,……,第1行第1列第0个元素非0,第1行第1列第4个元素非0.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值