最近在做图结构相关的算法,scatter能把邻接矩阵里的信息修改,或者把邻居分组算个sum或者reduce,挺方便的,简单整理一下。
torch.scatter 与 tensor._scatter
Pytorch自带的函数,用来将作为src
的tensor根据index
的描述填充到input
中,形式如下:
ouput = torch.scatter(input, dim, index, src)
# 或者是
input.scatter_(dim, index, src)
两个方法的功能是相同的,而带下划线的_scatter
方法是将原tensor input
直接修改了,不带的则会返回一个新的tensor output
,input
不变。
其中dim
决定index
对应值是沿着哪个维度进行修改。而src
为数据来源,当其为tensor张量时,shape要和index相同,这样index中每个元素都能对应src
中对应位置的信息。
理解scatter
方法主要是要理解index
实现的src
和input
之间的位置对应关系,举个例子:
dim = 0
index = torch.tensor(
[[0, 2, 2],
[2, 1, 0]]
)
dim
为0时,遵循的映射原则为:input[index[i][j]][j] = src[i][j]
.
也就是说,将位置 (i, j) 中dim
对应的位置改为 index[i][j] 的值。如位置(1,0),index[1][0]为2,则映射后的位置为(2,0),意味着input
中(2,0)的位置被更改为src
中(1,0)位置的值。
我个人形象理解是这些值会沿着dim方向滑动,上面例子中src[1][0]位置的值滑到2,成为input中的新值,这样理解起来更形象一点。
基本理解了上面这个例子,多维情况和不同dim的情况都可以类推了。
需要注意:src和input的dtype需要相同,不然会报Expected self.dtype to be equal to src.dtype
,不一样就先转换再使用。
t = torch.arange(6).view(2, 3)
t = t.to(torch.float32)
print(t)
output = torch.scatter(torch.zeros((3