一、函数简介
torch.scatter(input, dim, index, src)
- dim ([int]) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as
src
. When empty, the operation returnsself
unchanged. - src ([Tensor] or [float] – the source element(s) to scatter.
- reduce ([str], optional) – reduction operation to apply, can be either
'add'
or'multiply'
.
将src中的数据根据index中的索引按照dim的方向填入到input中
Writes all values from the tensor
src
intoself
at the indices specified in theindex
tensor. For each value insrc
, its output index is specified by its index insrc
fordimension != dim
and by the corresponding value inindex
fordimension = dim
.
看了上述的官方文档还是不理解,我们继续看官方的例子,这里官方只给了三维,我在这里又加入了二维,在这之前有一个规定
- 对任意维度d:
index.size(d) <= src.size(d)
- 对d!=dim的维度:
index.size(d) <= self.size(d)
二、二维举例
self[index[i][j]][j] = src[i][j] # if dim == 0
self[i][index[i][j]] = src[i][j] # if dim == 1
先上代码
torch.manual_seed(0)
x = torch.arange(0, 12).reshape(2, 6)
x= x.type(torch.float32)
print(x)
'''
tensor([[ 0., 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10., 11.]])
'''
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
print(index)
'''
tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
'''
y = torch.zeros(3, 6)
y = torch.scatter(y, 0, index, x)
print(y)
'''
tensor([[ 0., 7., 8., 3., 4., 0.],
[ 0., 1., 0., 9., 0., 0.],
[ 6., 0., 2., 0., 10., 0.]])
'''
'''in-place operation'''
yy = torch.zeros(3, 6)
yy.scatter_(0, index, x)
print(yy)
'''
tensor([[ 0., 7., 8., 3., 4., 0.],
[ 0., 1., 0., 9., 0., 0.],
[ 6., 0., 2., 0., 10., 0.]])
'''
下面将上述的执行过程绘制出来
三、详解执行过程
1. 第一步
首先下面是我们的初始化,我们初始化了
src
,index
,input
并且设置了dim=0
填充公式为self [ index[i][j] ][j] = src[i][j]
因为dim=0
,所以需要填充的input的行的索引就由index数值也就是index[i][j]来确定,需要填充的input的列的索引对应于index的列,往self里面填充的具体数值由index对应的src来确定
看下面例子序号3:
- 需要填充的input的行的索引:
行=index[i][j]=index[0][2]=2
- 需要填充的input的列的索引:
列=index列=j=2
- self填充的具体数值:
self[行][列]=self[2][2]=src[i][j]=2
2. 第二步
下面我们继续上述步骤
可发现当我们进行到第六步的时候,index[0][5]
并不存在,所以直接跳过就可以了
3. 第三步
在这一步我们将input
填充完毕
如图所示,这里我们取图中的序号11进行验证
序号11:
- 需要填充的input的行的索引:
行=index[i][j]=index[1][4]=2
- 需要填充的input的列的索引:
列=index列=j=4
- self填充的具体数值:
self[行][列]=self[2][4]=src[i][j]=10
- 所以在self的第二行,第四列填入10
4. 问题
为什么在第二步我们遇到的问题吗:当我们进行到序号6的时候,index[0][5]
并不存在,我们选择了跳过
可以跳过而没有报错呢,因为最初的文档对src, index, self
的维度有过定义
- 对任意维度d:
index.size(d) <= src.size(d)
- 对d!=dim的维度:
index.size(d) <= self.size(d)
所以index
的维度是可以小于src
的维度的,关系如下