今天读代码,看到有人用了torch.Tensor.scatter这个函数。这个函数之前我也看到过,但是没有搞明白是干啥用的,今天我搞明白了。
首先看一下官方文档的定义:(可跳过)
==============================pytorch docs分割线====================================
scatter_
(dim, index, src) → Tensor
Writes all values from the tensor src
into self
at the indices specified in the index
tensor. For each value in src
, its output index is specified by its index in src
for dimension != dim
and by the corresponding value in index
for dimension = dim
.
For a 3-D tensor, self
is updated as:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
This is the reverse operation of the manner described in gather()
.
self
, index
and src
(if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d)
for all dimensions d
, and that index.size(d) <= self.size(d)
for all dimensions d != dim
.
Moreover, as for gather()
, the values of index
must be between 0
and self.size(dim) - 1
inclusive, and all values in a row along the specified dimension dim
must be unique.
Parameters
-
dim (int) – the axis along which to index
-
index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
-
src (Tensor) – the source element(s) to scatter, incase value is not specified
-
value (float) – the source element(s) to scatter, incase src is not specified
-
===============================================================================
通过自己的实验,我明白了这个函数的目的,下面讲解一下:
首先看一下这个函数的接口,需要三个输入:1)维度dim 2)索引数组index 3)原数组src,为了方便理解,我们后文把src换成input表示。最终的输出是新的output数组。
下面依次介绍:
1)维度dim:整数,可以是0,1,2,3...
2)索引数组index:索引数组是一个tensor,其中的数据类型是整数,表示位置
3)原数组input:也是一个tensor,其中的数据类型任意
先说一下这个函数是干嘛的,在我看来,这个scatter函数就是把input数组中的数据进行重新分配。index中表示了要把原数组中的数据分配到output数组中的位置,如果未指定,则填充0。
比如说下面这段代码:
import torch
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)
运行结果如下:
tensor([[-0.2558, -1.8930, -0.7831, 0.6100],
[ 0.3246, 2.1289, 0.5887, 1.5588]])
tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000],
[ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])
下面,我详细说一下为什么会是这样的结果。
前面说了,scatter是input数组,根据index数组,对input数组中的数据进行重新分配,我们看一下分配过程是怎样的。
input:
tensor([[-0.2558, -1.8930, -0.7831, 0.6100],
[ 0.3246, 2.1289, 0.5887, 1.5588]])
index:
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output:
tensor([[ 0.6100, -1.8930, -0.7831, -0.2558, 0.0000],
[ 0.5887, 0.3246, 2.1289, 1.5588, 0.0000]])
首先,对input[0][0]进行重分配。符号 -> 代表赋值。由于scatter方法的第一维dim=1,所以input数组中的数据只是在第1维上进行重新分配,第0维不变。以二维数组举例,第一行的数据重新分配后一定在还是第一行,不能跑到第二行。
input[0][0] -> output[0][index[0][0]] = output[0][3]
数据位置发生的变化都是在第1维上,第0维不变。
input[0][1] -> output[0][index[0][1]] = output[0][1]
input[0][2] -> output[0][index[0][2]] = output[0][2]
input[0][3] -> output[0][index[0][3]] = output[0][0]
需要注意的是,
为了方便理解,我们是按照input中数据的顺序索引的,但是在pytorch中,是根据从index[0][0]到index[0][3]这样的顺序去索引的,索引的input位置和output的位置必须要存在,否则会提示错误。但是,不一定所有的input数据都会分到output中,output也不是所有位置都有对应的input,当output中没有对应的input时,自动填充0。
一般scatter用于生成onehot向量,如下所示:
index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)
输出结果是:
tensor([[0., 1., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 1.]])
如果input是一个数字的话,代表这用于分配到output的数字是多少。