Pytorch的scatter函数详解

前言

 在看FCOS算法源码时,发现获取正样本点用到了scatter这个函数,故记录下。

1、官方文档解释

  先贴出链接:scatter官方解读

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

 接收三个参数: dim, index和src。该函数作用就是在dim维度上,根据index提供的索引,从src中提取对应元素来赋值给Tensor。 以下是官方给的一个三维张量例子。
在这里插入图片描述
 需要注意两个点:index和src的dim维度数必须一样! 以官方3-D tensor为例,即self、src和index的维度均为3;若是2D-tensor则self、src和index的维度均为2。因为需要用index的元素作为索引,故index中元素的大小必须<self.size(d) 且 src.size(d)。

2、举个例子

在这里插入图片描述
附上代码:

src = torch.Tensor([[0,1,2,3,4],[5,6,7,8,9]])
self = torch.zeros((3,5))
index = torch.tensor([[0,1,2,0,0],[2,0,0,1,2]])
self.scatter_(dim=0,index=index,src=src)
print(self)

输出:
在这里插入图片描述

总结

  在实际编程中,src往往是标量,即是个常数。根据定义,等式右边的src[i][j] 恒等于标量。即此时scatter函数作用就是根据index将self中对应位置变成常数即可。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值