消息传递范式
在图神经网络中,为节点生成节点表征是图计算任务成功的关键。在此小节,本节学习基于神经网络的生成节点表征的范式——消息传递范式。
消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。
消息传递范式包含三个步骤:
- 邻接节点信息变换
- 邻接节点信息聚合到中心节点
- 聚合信息变换
神经网络的生成节点表征的操作可称为节点嵌入(Node Embedding),节点表征也可以称为节点嵌入。
下图展示了基于消息传递范式的生成节点表征的过程:
- 在图的最右侧,B节点的邻接节点(A,C)的信息传递给了B,经过信息变换得到了B的嵌入,C、D节点同。
- 在图的中右侧,A节点的邻接节点(B,C,D)的之前得到的节点嵌入传递给了节点A;在图的中左侧,聚合得到的信息经过信息变换得到了A节点新的嵌入。
- 重复多次,我们可以得到每一个节点的经过多次信息变换的嵌入。这样的经过多次信息聚合与变换的节点嵌入就可以作为节点的表征,可以用于节点的分类。
Pytorch Geometric中的MessagePassing
基类
Pytorch Geometric(PyG)提供了MessagePassing
基类,它实现了消息传播的自动处理,继承该基类可使我们方便地构造消息传递图神经网络,我们只需定义函数
ϕ
\phi
ϕ,即message()
函数,和函数
γ
\gamma
γ,即update()
函数,以及使用的消息聚合方案,即aggr="add"
、aggr="mean"
或aggr="max"
。
MessagePassing
主要提供了以下方法:
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
MessagePassing.propagate(edge_index, size=None, **kwargs)
MessagePassing.message(...)
MessagePassing.aggregate(...)
MessagePassing.message_and_aggregate(...)
MessagePassing.update(aggr_out, ...)
以上内容可参考The “MessagePassing” Base Class。
继承MessagePassing
类的GCNConv
GCNConv的数学公式定义如下**
x
i
(
k
)
=
∑
j
∈
N
(
i
)
∪
{
i
}
1
deg
(
i
)
⋅
deg
(
j
)
⋅
(
Θ
⋅
x
j
(
k
−
1
)
)
,
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),
xi(k)=j∈N(i)∪{i}∑deg(i)⋅deg(j)1⋅(Θ⋅xj(k−1)),
这个公式可以分为以下几个步骤:
- 向邻接矩阵添加自环边。
- 线性转换节点特征矩阵。
- 计算归一化系数。
- 归一化 j j j中的节点特征。
- 将相邻节点特征相加("求和 "聚合)。
GCNConv
继承了MessagePassing
并以"求和"作为领域节点信息聚合方式。该层的所有逻辑都发生在其forward()
方法中。
propagate
函数
在message
函数中希望接受到哪些数据,就要在propagate
函数的调用中传递哪些参数。
message
函数
message
函数接收两个参数x_j
和norm
,而propagate
函数被传递三个参数edge_index, x=x, norm=norm
。由于x
是Data
类的属性,且message
函数接收x_j
参数而不是x
参数,所以在propagate
函数被调用,message
函数被执行之前,一项额外的操作被执行,该项操作根据edge_index
参数从x
中分离出x_j
。事实上,在message
函数里,当参数是Data
类的属性时,我们可以在参数名后面拼接_i
或_j
来指定要接收源节点的属性或是目标节点的属性。
aggregate
函数
在前面的例子中增加如下的aggregate
函数,通过观察运行结果我们发现,我们覆写的aggregate
函数被调用,同时在super(GCNConv, self).__init__(aggr='add')
中传递给aggr
参数的值被存储到了self.aggr
属性中。
message_and_aggregate
函数
在一些例子中,消息传递与消息聚合可以融合在一起,这种情况我们通过覆写message_and_aggregate
函数来实现:
update
函数
update
函数接收聚合的输出作为第一个参数,并接收传递给propagate
的任何参数。