pytorch里torch.gather()和torch.Tensor.scatter()解析

torch.Tensor.scatter() 类似 gather 的反向操作(gather是读出数据,scatter是写入数据),所以这里只解析torch.gather()。
gather()这个操作在功能上较为反人类,即使某段时间理解透彻了,过了几个月不碰可能又会变得生疏。官方文档对其描述也是较为简单,有些小伙伴看完可能还是不完全理解,本文从根本上去解析这个操作的功能。
概括地说,gather()是index_select()的延伸操作,比index_select()更加灵活,它的操作不属于块操作,而是元素级别的操作,所以性能上应该较低,我们应该尽可能地避免使用这个操作。

下面开始解析这个操作。

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

这个功能的设计目的是“Gathers values along an axis specified by dim.”,这是官方文档的所有描述,看到这句话做目标检测的小伙伴应该能想到这样一个场景:

目标检测网络输出矩阵,前4列是box的坐标,第5列表示检测到目标的种类标签
print(pred)
tensor([[0.0080, 0.6403, 0.9865, 0.0158, 1.0000],
        [0.2742, 0.7470, 0.3837, 0.6689, 3.0000],
        [0.3260, 0.6683, 0.1888, 0.9525, 0.0000],
        [0.7989, 0.9154, 0.1040, 0.5538, 3.0000],
        [0.6746, 0.6193, 0.0161, 0.5166, 0.0000]])

现在我们要挑选出标签是3的所有检测目标框,

i = pred[:, 4].eq(3).nonzero().repeat(1, 4)
torch.gather(pred, 0, i)
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
        [0.7989, 0.9154, 0.1040, 0.5538]])

gather()可以实现实现这种整行地抽取数据,但不是最优的实现方法,我们有更合适的实现方法,index_select()和下标索引:

i = pred[:, 4].eq(3).nonzero().squeeze()
pred.index_select(0, i)[:, :4]
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
        [0.7989, 0.9154, 0.1040, 0.5538]])
        
# 下标索引方法        
pred[i, :4]
tensor([[0.2742, 0.7470, 0.3837, 0.6689],
        [0.7989, 0.9154, 0.1040, 0.5538]])

现在我们要进行更加复杂的数据抽取,输出张量的要求如下:

  • shape是2*2
  • 第0行的第0列对应原始数据pred的第0行第0列,第1列对应pred的第1行第1列,在图中红色元素
  • 第1行的第0列对应原始数据pred的第3行第0列,第1列对应pred的第2行第1列,在图中蓝色元素
    数据抽取

这时候index_selsect()无法实现,但gather()可以

index = torch.tensor([[0, 1],
				  	  [3, 2]])
torch.gather(pred, 0, index)
tensor([[0.0080, 0.7470],
        [0.7989, 0.6683]])

这个操作的规则如下:

  • 输出张量的shape和索引张量(index)相同

  • 除了dim指示的那个维度,其他所有的维度满足条件: index.size(d) <= input.size(d)

  • index和输入张量input的每个维度一一对应

  • 除了dim指示的那个维度,其他维度的input和output元素位置对应,当index.size(d) < input.size(d)时候,从最前面截取
    在这里插入图片描述
    在这里插入图片描述

  • dim指示的那个维度上数据根据index里具体元素指示的位置去定位

看起来还是不好理解的,好在这个函数的应用场景不多,到目前为止我还没遇到适合这个函数的应用场景,如果哪位小伙伴遇到了请评论区留言感激不尽。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值