关于scatter_add函数的用法
两种scatter函数的关系
torch_scatter
torch_scatter是pytorch_geometric作者基于pytorch做的small extension library of highly optimized sparse update (scatter and segment) operations
scatter_add_
是pytorch中实现的函数,上述函数很多是基于此所作,只不过当前函数侧重于矩阵的计算,而前者侧重于图相关的计算
scatter_add_
文字解释
scatter_add_
是scatter
的一个例子,pytorch对scatter函数的解释如下:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
>>>self.scatter_(dim, index, src, reduce)
- dim即维度,是对于self而言的,即在self的哪一dim进行操作
- index是索引,即要在self的哪一index进行操作 index的维度可以小于等于src,如果二者维度相同,则相当于将src的每一个数字都加到self的对应index上;如果index维度小,例如src: shape[5,3], index: shape[3,2]则代表只有src[:3,:2]的数字参与了操作
- src是待操作的源数字,比较好理解
- reduce代表操作的方式,none代表直接赋值,add则是+=,multiply是*= 因此scatter的意思就是 将src中前index部分的数字以一定的方式scatter(散布)到self中
以代码和对应图像为例对上述进行解释
src = torch.arange(1, 11).reshape((2, 5))
src
>>>tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
>>>tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
>>>tensor([[4, 2, 3, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
图中相同颜色的填充代表同一个位置,线条颜色则代表数字的分配,dim指对应self的维度;index分别应该是self的[dim,[index]],对应src待操作的数字应该是src[:index.shape]
总结
按照上述图示来看,需要注意的几点就是index的数值和维度分别对应的是self和src的取值
scatter_add
有了上述的理解,对于torch_scatter中的scatter_add更好理解了
src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0,1,2,0,3],[0,1,1,2,2]])
torch_scatter.scatter_add(src, index)
>>>
tensor([[ 5, 2, 3, 5],
[ 6, 15, 19, 0]])
torch_scatter.scatter_add(src, index, dim=0)
>>>
tensor([[ 7, 0, 0, 4, 0],
[ 0, 9, 8, 0, 0],
[ 0, 0, 3, 9, 10],
[ 0, 0, 0, 0, 5]])
需要注意的几点:
- dim默认为-1
- index的值代表的是输出的维度,比如最大为100则输出的dim对应的维度为101
- 源码中开始会做一个broadcast将维度扩展
对于此函数,主要知道其应用场景:
scatter_add(edge_weight, edge_index[1], dim=0)
其意义就是将每个target node的与其邻接节点的边的权重之求和,最终得到的输出维度是节点数目;如果weight是0或者1,则得到的是degree,如果选择的是target节点则是入度,否则是出度。