torch.scatter算子详解

0 scatter理解

关于该算子, torch 官方的文档https://pytorch-cn.readthedocs.io/zh/latest/package_references/Tensor/#scatter_input-dim-index-src-tensor 是这么解释的:
在这里插入图片描述
刚开始看了几次, 挺费解的。仔细理解之后, 发现其实用法也挺简单的。 这个操作的作用就是把src这个Tensor的值给更新到input这个Tensor中。 那更新到哪些位置呢, 就是由index 和dim去确定了。

拿上面的例子来说, 输入input是一个维度为[3,5]的Tensor, src是一个维度为[2,5]的Tensor。 那么很自然的, 要把src中这10个值更新到input中, 显然需要10个位置索引去决定更新到哪些位置。 那么如何去表示这10个索引呢, 这里用index 和dim两个值共同确定。 首先dim=0, 表示index是按第0维度(也就是行)去索引input的。

index的第一行是[0,1,2,0,0] 就表示把src中的第一行的5个值分别更新到input的第0行, 第1行, 第2行, 第0行和第0行。 列是相同的, 也就是src中第1列对应的也是input中的第一列。 其实把index写完整就更好理解了,完整的index应该是[[0,0],[1,1],[2,2],[0,3],[0,4]]. 因为列是相同的, 所以省去了列值, 这导致有一些不好理解。 其实如果写出完整的index也就不需要dim这个参数了。 之所以不用完整的索引值, 而是用不完全的index和dim共同确定最终的index, 应该是为了简化index的写法。

index的第一行是[2,0,0,1,2] 就表示把src中的第一行的5个值分别更新到input的第2行, 第0行, 第0行, 第1行和第2行。完整的index应该是[[2,0],[0,1],[0,2],[1,3],[2,4]].

下面的图非常直观的表示了这一过程:
在这里插入图片描述

1 作用

上面花了较大的篇幅介绍了scatter的具体作用,看着还挺复杂的。 那么这个操作到底有什么用呢? 实际上, 这个操作基本上都用在one-hot的操作中。 one-hot的操作中, 就需要用到这个操作, 把索引指向位置的值更新为1.

  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值