介绍
信息传递网络(Message Passing Neural Networks, MPNNs)是由Gilmer等人提出的一种图神经网络通用计算框架。原文以量子化学为例,根据原子的性质(对应节点特征)和分子的结构(对应边特征)预测了13种物理化学性质。查看论文原文请点击这里。
机制
理论
MPNN的前向传播包括两个阶段,第一个阶段称为message passing(信息传递)
阶段,第二个阶段称为readout(读取)
阶段。定义一张图
G
=
(
V
,
E
)
G=(V,E)
G=(V,E),其中
V
V
V是所有节点,
E
E
E是所有边。
信息传递阶段
message passing
阶段会执行多次信息传递过程。对于一个特定的节点v
,我先给出公式。
m
v
t
+
1
=
∑
w
∈
N
(
v
)
M
t
(
h
v
t
,
h
w
t
,
e
v
w
)
(1)
m_v^{t+1}=\sum_{w\in N(v)}M_t\left( h_v^{t},h_w^{t},e_{vw} \right)\tag{1}
mvt+1=w∈N(v)∑Mt(hvt,hwt,evw)(1)
h
v
t
+
1
=
U
t
(
h
v
t
,
m
v
t
+
1
)
(2)
h_v^{t+1}=U_t\left(h_v^{t},m_v^{t+1}\right)\tag{2}
hvt+1=Ut(hvt,mvt+1)(2)
其中,在公式
(
1
)
(1)
(1)中,
m
v
t
+
1
m_v^{t+1}
mvt+1是结点v
在t+1
时间步所接收到的信息,
N
(
v
)
N(v)
N(v)是结点v
的所有邻结点,
h
v
t
h_v^{t}
hvt是结点v
在t
时间步的特征向量,
e
v
w
e_{vw}
evw是结点v
和w
的边特征,
M
t
M_t
Mt是消息函数。该公式的意义是节点v
收到的信息来源于节点v
本身状态(
h
v
t
h_v^{t}
hvt),周围的节点状态(
h
w
t
h_w^{t}
hwt)和与之相连的边特征(
e
v
w
e_{vw}
evw)。生成信息后,就需要对结点进行更新。
在公式 ( 2 ) (2) (2)中, U t U_t Ut是结点更新函数,该函数把原节点状态 h v t h_v^{t} hvt和信息 m v t + 1 m_v^{t+1} mvt+1作为输入,得到新的节点状态 h v t + 1 h_v^{t+1} hvt+1。熟悉RNN的同学可能会眼熟这个公式,这个更新函数和RNN里的更新函数是一样的。后面我们也可以看到,我们可以用GRU或LSTM来表示 U t U_t Ut。
最后再强调一下时间步的概念。计算完一次 ( 1 ) (1) (1)和 ( 2 ) (2) (2)算一个时间步,因此如果时间步设为 T T T,上述两个公式会各运行 T T T次,最终得到的结果是 h v T h_v^{T} hvT。
读取阶段
readout
阶段使用读取函数
R
R
R计算基于整张图的特征向量,可以表示为
y
^
=
R
(
{
h
v
T
∣
v
∈
G
}
)
(3)
\hat{y}=R\left(\{h_v^T|v \in G \} \right)\tag{3}
y^=R({hvT∣v∈G})(3)
其中,
y
^
\hat{y}
y^是最终的输出向量,
R
R
R是读取函数,这个函数有两个要求:1、要可以求导。2、要满足置换不变性(结点的输入顺序不改变最终结果,这也是为了保证MPNN对图的同构有不变性)
实际案例
在MPNN的框架下,我们可以自定义消息函数、更新函数和读取函数,下面我举一个实际的案例,也是这篇文章所提及的门控图神经网络(Gated Graph Neural Networks, GG-NN)。这里,信息函数、结点更新函数和读取函数被定义为
M
t
(
h
v
t
,
h
w
t
,
e
v
w
)
=
A
e
v
w
h
w
t
(4)
M_t\left( h_v^{t},h_w^{t},e_{vw} \right)=A_{e_{vw}}h_w^t\tag{4}
Mt(hvt,hwt,evw)=Aevwhwt(4)
U
t
(
h
v
t
,
m
v
t
+
1
)
=
G
R
U
(
h
v
t
,
m
v
t
+
1
)
(5)
U_t\left(h_v^{t},m_v^{t+1}\right)=GRU\left(h_v^{t},m_v^{t+1}\right)\tag{5}
Ut(hvt,mvt+1)=GRU(hvt,mvt+1)(5)
R
=
∑
v
∈
V
σ
(
i
(
h
v
(
T
)
,
h
v
0
)
)
⊙
(
j
(
h
v
(
T
)
)
)
(6)
R=\sum_{v\in V}\sigma\left(i\left(h_v^{(T)},h_v^0\right)\right)\odot \left(j\left(h_v^{(T)}\right)\right)\tag{6}
R=v∈V∑σ(i(hv(T),hv0))⊙(j(hv(T)))(6)
消息函数
(
4
)
(4)
(4)中,矩阵
A
e
v
w
A_{e_{vw}}
Aevw决定了图中的结点是如何与其他结点进行相互作用的,一条边对应一个矩阵。但是这个函数描述得有些笼统。GGNN文章中的公式更清晰一些,如下所示
a
v
(
t
)
=
A
v
:
T
[
h
1
(
t
−
1
)
T
,
h
2
(
t
−
1
)
T
,
.
.
.
,
h
∣
V
∣
(
t
−
1
)
T
]
T
+
b
(7)
a_v^{\left(t\right)}=A_{v:}^T\left[h_1^{(t-1)^T},h_2^{(t-1)^T},...,h_{|V|}^{(t-1)^T}\right]^T+b\tag{7}
av(t)=Av:T[h1(t−1)T,h2(t−1)T,...,h∣V∣(t−1)T]T+b(7)
其中,
a
v
(
t
)
a_v^{\left(t\right)}
av(t)是结点v
在t
时刻接收到的信息向量,和我们之前定义的
m
v
t
+
1
m_v^{t+1}
mvt+1是一样的,只是换了些字母。
h
(
t
−
1
)
h^{(t-1)}
h(t−1)表示节点在t-1
个时间步的状态,因此
[
h
1
(
t
−
1
)
T
,
h
2
(
t
−
1
)
T
,
.
.
.
,
h
∣
V
∣
(
t
−
1
)
T
]
T
\left[h_1^{(t-1)^T},h_2^{(t-1)^T},...,h_{|V|}^{(t-1)^T}\right]^T
[h1(t−1)T,h2(t−1)T,...,h∣V∣(t−1)T]T把每个结点的状态拼接在一个维度上,维度大小为
D
∣
V
∣
D|V|
D∣V∣,
b
b
b是偏置项,至于
A
v
:
A_{v:}
Av:,我们先看下面这张图
这里有一个边的特征矩阵
A
A
A,矩阵
A
A
A考虑了边的方向,因此它是由out
和in
两个部分拼接而成,图中的不同字母代表了不同的相互作用类型(也可以视为每条边的特征,注意每一条边的特征维度都是
(
D
,
D
)
(D, D)
(D,D),而不是我们常见的一维向量,在实际应用中,如果边的初始特征维度不是
D
D
D,可以进行embedding
或线性变换到
D
×
D
D\times D
D×D维,再reshape
到
(
D
,
D
)
(D, D)
(D,D)),最终的维度是
(
D
∣
V
∣
,
2
D
∣
V
∣
)
(D|V|,2D|V|)
(D∣V∣,2D∣V∣),其中
∣
V
∣
|V|
∣V∣是结点个数。有了矩阵
A
A
A之后,我们需要针对某一个结点选出“两列”
(并非真正意义上的两列)。以2号结点作为v
结点为例,我们在Outgoing Edges
和Incoming Edges
中分别找到2号结点,再把这两列拼接起来,得到一个维度是
(
D
∣
V
∣
,
2
D
)
(D|V|,2D)
(D∣V∣,2D)的矩阵
A
v
:
A_{v:}
Av:。将该矩阵的转置与所有节点的状态拼接成的列向量相乘,最终得到一个维度为
2
D
2D
2D的信息向量
a
v
(
t
)
a_v^{\left(t\right)}
av(t)。而对于无向图而言,只需要考虑一半的情况就行了。
结点更新函数
(
5
)
(5)
(5)是GRU
,对GRU
不熟悉的同学可以看一下这方面的知识,在此就不再多做解释了。
读取函数
(
6
)
(6)
(6)看起来是较为复杂的,我们可以拆开来看。首先
⊙
\odot
⊙表示逐元素相乘,
i
i
i和
j
j
j分别表示一个全连接神经网络,并且在
i
i
i的外面又套了一层sigmoid
函数,用符号
σ
\sigma
σ表示。对于神经网络
i
i
i而言,输入是结点的初始状态和最终状态,因此输入维度是2 * in_dim
,而对于神经网络
j
j
j而言,输入只有结点的最终状态,因此输入维度是in_dim
。但是这两个神经网络的输出维度是一样的,这样才能逐元素相乘。再往深入一点讲,这里包含了self attention
机制,就是在读取阶段要注意
该节点最初的特征。
代码
我分别找到了Pytorch和Tensorflow的实现,以后有时间我会分析一下Pytorch版的实现过程。
Pytorch版
Tensorflow版(原作者)