一.前言
众所周知,图神经网络可以从空域或谱域来对其进行研究。其中,空域角度主要借助消息传播机制来构建GNN。本文主要介绍的是消息传递机制,为下篇文章具体介绍PyG中是如何实现消息传播机制做好铺垫。
二.消息传递框架概述
消息传递是Gilmer等在Neural Message Passing for Quantum Chemistry中提出来的从空域角度定义GNN的范式。假设
x
i
(
k
−
1
)
∈
R
F
\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F
xi(k−1)∈RF表示节点
i
i
i在第
k
−
1
k-1
k−1层的特征,
e
j
,
i
∈
R
D
\mathbf{e}_{j,i} \in \mathbb{R}^D
ej,i∈RD表示节点
j
j
j到节点
i
i
i的边上的特征,则消息传播机制可以用如下公式来描述:
x
i
(
k
)
=
γ
(
k
)
(
x
i
(
k
−
1
)
,
□
j
∈
N
(
i
)
ϕ
(
k
)
(
x
i
(
k
−
1
)
,
x
j
(
k
−
1
)
,
e
j
,
i
)
)
(1)
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) \tag{1}
xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i))(1)
在消息传播机制中,主要分为三大步骤:消息生成、消息聚合、消息更新。
2.1 消息生成与传播
在本阶段中,每个节点将生成自己的消息,然后向自己的邻居节点”传播“自己的消息,也就是公式(1)中的:
ϕ
(
k
)
(
x
i
(
k
−
1
)
,
x
j
(
k
−
1
)
,
e
j
,
i
)
\phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right)
ϕ(k)(xi(k−1),xj(k−1),ej,i)
其中,
ϕ
(
k
)
\phi^{(k)}
ϕ(k)表示可微函数,例如MLP。在生成消息的过程中,可能会用到:
- 节点自己当前的特征( x i ( k − 1 ) \mathbf{x}_i^{(k-1)} xi(k−1))
- 节点邻居当前的特征( x j ( k − 1 ) \mathbf{x}_j^{(k-1)} xj(k−1))
- 节点与其邻居间边的特征( e j , i \mathbf{e}_{j,i} ej,i)
当然上述三者并不都是必须的,具体使用什么来生成节点的消息取决于GNN的构建者。
2.2 消息聚合
在本阶段,每个节点会聚合来自邻居的消息,也就是公式(1)中的:
□
j
∈
N
(
i
)
(
Message
)
\square_{j \in \mathcal{N}(i)}(\text{Message})
□j∈N(i)(Message)
其中
Message
\text{Message}
Message指代2.1节中每个节点的消息,
N
(
i
)
\mathcal{N}(i)
N(i)表示节点
i
i
i的邻域,
□
\square
□表示可微的、转置不变(permutation invariant)函数。转置不变指聚合邻居的消息的结果与邻居的聚合顺序无关,常见的包括sum,max,mean
。
2.3 消息更新
在本阶段,每个节点利用聚合自邻居节点的消息来生成自己的消息,也就是公式(1)中的:
γ
(
k
)
(
x
i
(
k
−
1
)
,
NeighborMsg
)
\gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \text{NeighborMsg}\right)
γ(k)(xi(k−1),NeighborMsg)
其中
N
e
i
g
h
b
o
r
M
s
g
NeighborMsg
NeighborMsg指代2.2节中每个节点聚合自邻居的消息,
γ
(
k
)
\gamma^{(k)}
γ(k)也表示可微函数,例如MLP。
2.4 消息传递机制小结
上述的消息传播机制可以用下图概括:
其中图中的方框便是的便是聚合邻居的神经网络。
经过前面的介绍可知:空域角度定义的GNN间的不同之处便在于它们关于消息生成、消息聚合和消息更新的实现的不同。
三.消息传播机制的示例
为了方便理解上面的消息传递机制,本节主要展示GCN、GraphSAGE中的消息传播机制。
3.1 GCN中的消息传播机制
GCN虽然是从谱域角度定义的,但同样从空域角度来对其进行解释,其所对应的消息传播机制如下:
h
i
(
l
+
1
)
=
σ
(
1
∣
N
(
i
)
∣
∑
j
∈
N
(
i
)
1
∣
N
(
j
)
∣
h
j
(
l
)
W
(
l
)
)
h_i^{(l+1)} = \sigma(\frac{1}{\sqrt{|\mathcal{N}(i)|}}\sum_{j\in\mathcal{N}(i)}\frac{1}{\sqrt{|\mathcal{N}(j)|}}h_j^{(l)}W^{(l)})
hi(l+1)=σ(∣N(i)∣1j∈N(i)∑∣N(j)∣1hj(l)W(l))
生成消息:
1
∣
N
(
j
)
∣
h
j
(
l
)
W
(
l
)
\frac{1}{\sqrt{|\mathcal{N}(j)|}}h_j^{(l)}W^{(l)}
∣N(j)∣1hj(l)W(l)
消息聚合:
σ
(
1
∣
N
(
i
)
∣
∑
j
∈
N
(
i
)
Message
)
\sigma(\frac{1}{\sqrt{|\mathcal{N}(i)|}}\sum_{j\in\mathcal{N}(i)}{\text{Message}})
σ(∣N(i)∣1j∈N(i)∑Message)
其中
σ
\sigma
σ指非线性激活,通常为ReLU。
消息更新:消息聚合来的消息。
3.2 GraphSAGE中的消息传播机制
GraphSAGE中的消息传播机制如下所示:
h
i
(
l
+
1
)
=
σ
(
W
(
l
)
⋅
CONCAT
(
h
i
(
l
)
,
A
G
G
(
{
h
j
(
l
)
,
∀
j
∈
N
(
i
)
}
)
)
)
h_{i}^{(l + 1)}=\sigma\left(W^{(l)} \cdot \operatorname{CONCAT}\left(h_{i}^{(l)}, \mathrm{AGG}\left(\left\{h_{j}^{(l)}, \forall j \in N(i)\right\}\right)\right)\right)
hi(l+1)=σ(W(l)⋅CONCAT(hi(l),AGG({hj(l),∀j∈N(i)})))
其中
norm
\text{norm}
norm表示L2正则化。
消息生成:
h
j
h_j
hj
两阶段聚合(首先聚合邻居的消息,然后聚合自身的消息):
h
N
(
i
)
(
l
+
1
)
←
A
G
G
(
{
h
j
(
l
)
,
∀
j
∈
N
(
i
)
}
)
σ
(
W
(
l
)
⋅
CONCAT
(
h
i
(
l
)
,
h
N
(
i
)
(
l
+
1
)
)
)
h_{N(i)}^{(l + 1)} \leftarrow \mathrm{AGG}\left(\left\{h_{j}^{(l)}, \forall j \in N(i)\right\}\right) \\ \sigma\left(W^{(l)} \cdot \operatorname{CONCAT}\left(h_{i}^{(l)},h_{N(i)}^{(l + 1)}\right)\right)
hN(i)(l+1)←AGG({hj(l),∀j∈N(i)})σ(W(l)⋅CONCAT(hi(l),hN(i)(l+1)))
其中
AGG
\text{AGG}
AGG在GraphSAGE中有三种实现:Mean、Pool或LSTM。
消息更新:消息聚合来的消息。
四.结语
参考资料:
GNN中的消息传播机制是借助PyG之类的图神经网络框架来编写自己的消息传播GNN的基础,只有对其了解较为深刻,你才能更好的设计自己的GNN模型。