torch_scatter.scatter_add
Official documentation: torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)
Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.
The description is rather confusing on first read, so I tried it myself:
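import torch
from torch_scatter import scatter_add

src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9])
index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0])
# dim defaults to -1, which is the only dimension of these 1-D tensors
out = scatter_add(src, index)
print(out)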
Output: tensor([ 9, 97, 10])
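Put simply: each value in index is a position in out, and the value of out at that position is the sum of all src values whose index entry equals that position.
In the example above, the entries of index equal to 1 correspond to the src values [20, 30, 40, 1, 2, 2, 2], which sum to 97, so out[1] = 97.
Likewise, out[0] = src[8] = 9 and out[2] = src[0] = 10.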
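The rule described above can be sketched in plain Python, without any tensor library. The helper name scatter_add_1d is hypothetical (it is not part of torch_scatter), and the sketch assumes the output length is max(index) + 1, which matches the size of the result shown above.

```python
# Plain-Python sketch of the 1-D case; scatter_add_1d is a hypothetical helper,
# not a torch_scatter function.
def scatter_add_1d(src, index, fill_value=0):
    # Assumption: the output has max(index) + 1 slots, each starting at fill_value.
    out = [fill_value] * (max(index) + 1)
    for value, position in zip(src, index):
        # Every src value is added to the slot named by its index entry.
        out[position] += value
    return out


src = [10, 20, 30, 40, 1, 2, 2, 2, 9]
index = [2, 1, 1, 1, 1, 1, 1, 1, 0]
out = scatter_add_1d(src, index)
print(out)  # [9, 97, 10]
```

The seven entries of index equal to 1 all land in out[1], which is why that slot accumulates 97.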
Another function
Tensor.scatter_add_
Official documentation:
scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as::
self[index[i][j][k]][j][k] += other[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] += other[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] += other[i][j][k] # if dim == 2
Official example:
>>> x = torch.rand(2, 5)
>>> x
tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
[0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
>>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
[1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
[1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])
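It is easier to follow if you iterate over index. Not every value in self has to change.
In the example above (dim = 0): index[0][0] = 0, so self[index[0][0]][0] = self[0][0] becomes self[0][0] + x[0][0] = 1 + 0.7404 = 1.7404.
index[0][1] = 1, so self[index[0][1]][1] = self[1][1] becomes self[1][1] + x[0][1] = 1 + 0.0427 = 1.0427.
...
Continuing in the same way over every entry of index gives the final result.
So the only entries of self that change are the positions named by index; all the others are left untouched.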
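The traversal over index can be written out as an explicit loop. This is only a plain-Python sketch for the 2-D, dim == 0 case using nested lists; scatter_add_dim0 is a hypothetical helper name, and the numbers are copied from the official example.

```python
# Plain-Python sketch of the dim == 0 update rule for 2-D data:
#   target[index[i][j]][j] += src[i][j]
# scatter_add_dim0 is a hypothetical helper, not a PyTorch function.
def scatter_add_dim0(target, index, src):
    for i, index_row in enumerate(index):
        for j, dest_row in enumerate(index_row):
            # Only positions named by index are touched; the rest keep their value.
            target[dest_row][j] += src[i][j]
    return target


x = [[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
     [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]]
index = [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]
target = [[1.0] * 5 for _ in range(3)]  # plays the role of torch.ones(3, 5)

result = scatter_add_dim0(target, index, x)
for row in result:
    print([round(v, 4) for v in row])
```

Each column receives exactly two additions here (one per row of index), so in every column one of the three rows stays at 1.0.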
Tensor.scatter_
scatter_(self, dim, index, src)
The difference from Tensor.scatter_add_ is that the values in src are written directly into self instead of being added to it.
Example:
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
[ 0.0000, 0.0000, 0.0000, 1.2300]])
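The examples above start from an all-zero tensor, so overwriting and adding would look identical. A small sketch that starts from ones makes the difference visible; the tensor values are made up for illustration, and the index entries are chosen so that no two of them target the same position.

```python
import torch

base = torch.ones(2, 3)
index = torch.tensor([[0, 1, 0]])       # destination row for each column
src = torch.tensor([[5.0, 6.0, 7.0]])

# scatter_ overwrites the selected positions with the src values.
overwritten = base.clone().scatter_(0, index, src)
# scatter_add_ adds the src values to whatever is already there.
accumulated = base.clone().scatter_add_(0, index, src)

print(overwritten)   # rows: [5., 1., 7.] and [1., 6., 1.]
print(accumulated)   # rows: [6., 1., 8.] and [1., 7., 1.]
```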
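In addition, PyTorch also provides the scatter_add and scatter functions (no trailing underscore). Unlike the two functions above, they do not modify self and instead return the result as a new tensor; the two functions above (scatter_add_ and scatter_) modify the original data self in place.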
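A short sketch of the in-place versus out-of-place behaviour, using torch.scatter_add and Tensor.scatter_add_ on made-up 1-D data:

```python
import torch

base = torch.zeros(3)
index = torch.tensor([0, 2, 2])
src = torch.tensor([1.0, 2.0, 3.0])

# Out-of-place: returns a new tensor and leaves base untouched.
result = torch.scatter_add(base, 0, index, src)
base_after_out_of_place = base.clone()
print(result)                    # tensor([1., 0., 5.])
print(base_after_out_of_place)   # tensor([0., 0., 0.])

# In-place: base itself is modified.
base.scatter_add_(0, index, src)
print(base)                      # tensor([1., 0., 5.])
```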