这块的内容介绍相当的少,官网给出的doc也描述的没有非常清晰,以至于我刚开始学的时候非常云里雾里。后来慢慢明白了之后,才慢慢看懂官网的doc。
首先给出官网的doc:
这一块我们要注意几个点:方块代表的是聚合邻居信息的操作,比方说sum,mean,max
,γ代表的是结合上一跳邻居信息和节点信息的更新操作,φ代表的是接受邻居信息的操作。
了解了这些知识,再看PyG的注解就容易了,我们需要实现的是这几个函数,就可以自定义GNN的一层:
总结下来就这几点:
- MessagePassing定义邻居信息的聚合器,即图2的Aggregation
- MessagePassing.propagate()是开始传播信息的函数,下面的message()函数和update()函数中需要的参数都要被包含进这个函数的参数中。
- message()函数是定义 x j x_j xj如何传播到节点 i i i的函数,即图2的φ函数。比方说,在GAT中, x j x_j xj传播到节点 i i i需要乘上注意力,那么就应该在这个函数里定义。值得注意的是,所有在propagate()中定义过的参数都可以在message()中直接使用,而且还可以在参数后面加上
_j
指定节点。比方说,我们在在propagate()中定义了参数x
表示所有节点的features,x大小为[N,features]
,那么我们就可以使用x_i
,x_j
表示节点j的特征,大小是[E,features]
,为什么大小是[E,features]
呢,我理解这里使用x_j
代表着所有边的源节点的特征,x_i
代表着所有边的目标节点的特征。 - 更新节点embedding的函数,对应于图二的γ,和message()函数一样可以直接使用propagate()中的参数,默认参数是
aggr_out
,这个参数代表着经过上一步的message()
之后所有节点获得的邻居节点的表示,大小是[N,in_channels]
。
MessagePassing—定义邻居信息聚合器
propagate():传播信息,指定要传播的信息和边的集合
message():接收信息,返回 x j x_j xj传到 x i x_i xi之后变成了啥
update():更新节点 i i i表示。
来看一个小栗子:
我们对照图2可以看出,此时的γ是没有的,也就是说不需要实现update()
函数,Aggregation是max,也就是说我们的MessagePassing需要定义为max,propagate()正常传播就行,message()里面比较复杂,在接收到