torch.scatter

一、函数简介

torch.scatter(input, dim, index, src)

  • dim ([int]) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
  • src ([Tensor] or [float] – the source element(s) to scatter.
  • reduce ([str], optional) – reduction operation to apply, can be either 'add' or 'multiply'.

将src中的数据根据index中的索引按照dim的方向填入到input中

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

看了上述的官方文档还是不理解,我们继续看官方的例子,这里官方只给了三维,我在这里又加入了二维,在这之前有一个规定

  • 对任意维度d:index.size(d) <= src.size(d)
  • 对d!=dim的维度:index.size(d) <= self.size(d)

二、二维举例

self[index[i][j]][j] = src[i][j] # if dim == 0
self[i][index[i][j]] = src[i][j] # if dim == 1

先上代码

torch.manual_seed(0)
x = torch.arange(0, 12).reshape(2, 6)
x= x.type(torch.float32)
print(x)
'''
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10., 11.]])
'''


index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
print(index)
'''
tensor([[0, 1, 2, 0, 0],
        [2, 0, 0, 1, 2]])
'''


y = torch.zeros(3, 6)
y = torch.scatter(y, 0, index, x)
print(y)
'''
tensor([[ 0.,  7.,  8.,  3.,  4.,  0.],
        [ 0.,  1.,  0.,  9.,  0.,  0.],
        [ 6.,  0.,  2.,  0., 10.,  0.]])
'''


'''in-place operation'''
yy = torch.zeros(3, 6)
yy.scatter_(0, index, x)
print(yy)
'''
tensor([[ 0.,  7.,  8.,  3.,  4.,  0.],
        [ 0.,  1.,  0.,  9.,  0.,  0.],
        [ 6.,  0.,  2.,  0., 10.,  0.]])
'''

下面将上述的执行过程绘制出来

三、详解执行过程

1. 第一步

首先下面是我们的初始化,我们初始化了srcindexinput并且设置了dim=0
填充公式为self [ index[i][j] ][j] = src[i][j]

因为dim=0,所以需要填充的input的行的索引就由index数值也就是index[i][j]来确定,需要填充的input的列的索引对应于index的列,往self里面填充的具体数值由index对应的src来确定

看下面例子序号3

  • 需要填充的input的行的索引:行=index[i][j]=index[0][2]=2
  • 需要填充的input的列的索引:列=index列=j=2
  • self填充的具体数值:self[行][列]=self[2][2]=src[i][j]=2

2. 第二步

下面我们继续上述步骤

可发现当我们进行到第六步的时候,index[0][5]并不存在,所以直接跳过就可以了

3. 第三步

在这一步我们将input填充完毕

如图所示,这里我们取图中的序号11进行验证

序号11

  • 需要填充的input的行的索引:行=index[i][j]=index[1][4]=2
  • 需要填充的input的列的索引:列=index列=j=4
  • self填充的具体数值:self[行][列]=self[2][4]=src[i][j]=10
  • 所以在self的第二行,第四列填入10

4. 问题

为什么在第二步我们遇到的问题吗:当我们进行到序号6的时候,index[0][5]并不存在,我们选择了跳过

可以跳过而没有报错呢,因为最初的文档对src, index, self的维度有过定义

  • 对任意维度d:index.size(d) <= src.size(d)
  • 对d!=dim的维度:index.size(d) <= self.size(d)

所以index的维度是可以小于src的维度的,关系如下在这里插入图片描述

  • 14
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值