Graph Neural Network
lecture 6,7,8详细介绍了图表示学习中的深度学习方法。之前介绍过Node Embedding,但是都是基于一些很“shallow”的特征,GNN可以帮助我们更高效地学习到更好的node、link、graph embedding。课程中所讲到的GNN都是spatial-based,也就是模型的结构是基于结点地空间特征,具体来说就是当前结点地embedding由它的neighbor得来,而spatial-based GNN遵循的一种模式叫做Message + Aggregate。
A Single Layer
GNN中的每一层遵循的是Message + Aggregate模式,不同的message passing和aggregate方式衍生出了不同的GNN模型,如:GCN、GraphSAGE、GAT等等。
Message Computation
首先来看如何计算message。每个结点都有自己的message,我们设
m
u
(
l
)
m_{u}^{(l)}
mu(l)表示结点
u
u
u在第
l
l
l层的message,
h
u
(
l
)
h_{u}^{(l)}
hu(l)表示结点
u
u
u在第
l
l
l层的embedding,那么message的计算公式为:
m
u
(
l
)
=
M
S
G
(
l
)
(
h
u
(
l
−
1
)
)
m_{u}^{(l)}\ =\ MSG^{(l)}(h_{u}^{(l-1)})
mu(l) = MSG(l)(hu(l−1))
其中,
M
S
G
(
l
)
MSG^{(l)}
MSG(l)表示第
l
l
l层的message function,选择有很多,最直接的就是乘一个参数矩阵
W
(
l
)
W^{(l)}
W(l),于是
m
u
(
l
)
=
W
(
l
)
h
u
(
l
−
1
)
m_{u}^{(l)}\ =\ W^{(l)}h_{u}^{(l-1)}
mu(l) = W(l)hu(l−1)。由于我们在更新每个结点的embedding时,也希望能把当前结点的message考虑进来,而不单单是考虑它neighbor的message,因此我们再定义一个参数矩阵
U
(
l
)
U^{(l)}
U(l)用于计算当前结点的message,即
m
v
(
l
)
=
U
(
l
)
h
v
(
l
−
1
)
m^{(l)}_{v}=U^{(l)}h^{(l-1)}_{v}
mv(l)=U(l)hv(l−1)。
Aggregate
Aggregate部分是将当前结点信息和其neighbor信息进行结合,得到当前结点的embedding,写成generalized的式子就是
h
u
(
l
)
=
A
G
G
(
l
)
(
{
m
v
(
l
)
,
v
∈
N
(
u
)
}
,
m
u
(
l
)
)
h_{u}^{(l)}\ =\ AGG^{(l)}(\{m_v^{(l)}, v \in N(u)\}, m_{u}^{(l)})
hu(l) = AGG(l)({mv(l),v∈N(u)},mu(l))
这里的
A
G
G
AGG
AGG就表示aggregate function。常见的aggregate function有Sum、Mean、Max等等.
Variants of GNN
根据上述的Message + Aggregate模式,我们就可以来分析一下几个经典的GNN
GCN
h u ( l ) = σ ( W ( l ) ∑ v ∈ N ( u ) h v ( l − 1 ) ∣ N ( u ) ∣ ) h^{(l)}_{u}\ =\ \sigma(W^{(l)}\sum_{v \in N(u)} \frac{h^{(l-1)}_{v}}{|N(u)|}) hu(l) = σ(W(l)v∈N(u)∑∣N(u)∣hv(l−1))
Message
从公式中可以看出,GCN的message function是
W
(
l
)
h
v
(
l
−
1
)
∣
N
(
u
)
∣
W^{(l)}\frac{h_{v}^{(l-1)}}{|N(u)|}
W(l)∣N(u)∣hv(l−1)
Aggregate
GCN中的Aggregate function使用的是Sum
GraphSAGE
h u ( l ) = σ ( W ( l ) ⋅ C O N C A T ( h u ( l − 1 ) , A G G ( { h v ( l − 1 ) , v ∈ N ( u ) } ) ) h_{u}^{(l)}\ =\ \sigma(W^{(l)} \cdot CONCAT(h_{u}^{(l-1)},AGG(\{h_{v}^{(l-1)}, v \in N(u)\})) hu(l) = σ(W(l)⋅CONCAT(hu(l−1),AGG({hv(l−1),v∈N(u)}))
GraphSAGE中的Aggregate分为两个部分,一个是对于neighbor embedding W ( l ) h v ( l − 1 ) W^{(l)}h_{v}^{(l-1)} W(l)hv(l−1)的Aggregate,这一步的aggregate function可以有多种选择;另一个就是当前结点message与上一步aggregate得到的结果进行aggregate,这里选用的是concatenation。回到第一步的aggregate function,比较常见的选择如下:
- Mean:简单的求一个均值,类似于GCN中的操作
- Pool:先对每个neighbor embedding做一个transformation,然后再用mean-pooling或者max-pooling
- LSTM:也可以把neighbor embedding当作序列信息然后用LSTM进行aggregate
GraphSAGE中还有一个小trick:对每一层每个结点的embedding进行 L 2 L_2 L2 normalization,这种做法在一些情况下能够提高performance
GAT
GAT是在GNN中引入了注意力机制。一个很intuitive的事实是,当我们在对一个结点的neighbor message进行aggregate的时候,每个neighbor的重要程度应该是不一样的。因此GAT中用
α
u
v
\alpha_{uv}
αuv表示
v
v
v的message对
u
u
u的一个权重
h
u
(
l
)
=
σ
(
∑
v
∈
N
(
u
)
α
u
v
W
(
l
)
h
v
(
l
−
1
)
)
h_{u}^{(l)}\ =\ \sigma(\sum_{v \in N(u) }\alpha_{uv}W^{(l)}h_{v}^{(l-1)})
hu(l) = σ(v∈N(u)∑αuvW(l)hv(l−1))
α
u
v
\alpha_{uv}
αuv的计算方式如下:
e
u
v
=
f
a
t
t
e
n
t
i
o
n
(
U
(
l
)
h
u
(
l
−
1
)
,
W
(
l
)
h
v
(
l
−
1
)
)
α
u
v
=
e
x
p
(
e
u
v
)
∑
v
′
∈
N
(
u
)
e
x
p
(
e
u
v
′
)
e_{uv}\ =\ f_{attention}(U^{(l)}h_{u}^{(l-1)}, W^{(l)}h_{v}^{(l-1)}) \\ \alpha_{uv}\ = \ \frac{exp(e_{uv})}{\sum_{v' \in N(u)}exp(e_{uv'})}
euv = fattention(U(l)hu(l−1),W(l)hv(l−1))αuv = ∑v′∈N(u)exp(euv′)exp(euv)
我们还可以像transformer一样使用multi-head attention
h
u
(
l
)
[
i
]
=
σ
(
∑
v
∈
N
(
u
)
α
u
v
(
i
)
W
(
l
)
h
v
(
l
−
1
)
)
h
u
(
l
)
=
A
G
G
(
{
h
u
(
l
)
[
i
]
,
i
=
1
,
2
,
…
,
n
}
)
h_{u}^{(l)}[i]\ =\ \sigma(\sum_{v \in N(u) }\alpha^{(i)}_{uv}W^{(l)}h_{v}^{(l-1)}) \\ h_{u}^{(l)}\ =\ AGG(\{h_{u}^{(l)}[i], i = 1,2,\dots,n\})
hu(l)[i] = σ(v∈N(u)∑αuv(i)W(l)hv(l−1))hu(l) = AGG({hu(l)[i],i=1,2,…,n})
GAT的优势如下:
- 让我们能够捕获到不同neighbor的重要程度
- 计算高效,这一点是attention共有的,可以并行计算
- 存储高效
- 具有Inductive的能力,不依赖于全局的结构,关注的是局部信息
Stacking Layers of GNN
讲完单层的GNN结构,下一步就应该增加网络的深度了。在CV或者NLP的一些模型中,通常我们把模型加的越深越好,模型越深,表达能力大概率会越强。但是GNN有所不同,如果一味的加深网络深度,那会带来一个Over-Smoothing的问题。
Over-Smoothing
首先介绍一个概念叫做Receptive Field:决定一个结点embedding的结点集合。由于每一层我们是用每个结点的neighbor来更新embedding,因此随着网络深度变大,每个结点的receptive field也会变得越来越大
这张图非常直观地展示了黄色结点receptive field逐渐变大地过程。因此,当网络太深时,每个结点的receptive field会出现很大程度的overlap,这就会导致每个结点的embedding最后会趋于一致,这就是over-smoothing problem。
为了解决over-smoothing问题,lecture中提到了两种解决方案:
- 增加单层GNN的表达能力。既然深度不能太大,那我们就可以在每一层GNN上尽可能的提升performance
- 添加残差结构,在层与层之间加skip connection
Expressivity of GNN
GNN模型的表达能力,简单来讲,就是模型能够区分不同结构的能力。首先介绍了一个概念叫做:computation graph
Computation Graph
每个结点的computation graph是由它的local neighborhood决定的,模型在做Aggregate的时候就是基于每个结点的计算图。以下图为例
这分别是5个结点的computation graph。我们可以看出,不同的local neighbor structure会带来不同的computation graph,而如果我们的模型对于不同的computation graph能够生成不同的embedding,我们就能区分不同的结点。
Graph Isomorphic Network(GIN)
根据上面的描述,我们发现GNN的表达能力的一大关键就是aggregate function的选择,我们最希望的就是对于不同的computation graph,我们aggregate出来的结果也是不同的,换言之,我们希望aggregate function是injective的。
先前我们所看到的各种aggregate function(sum、max、mean)其实都有一些failure cases,做不到injective。
因此GIN模型提出使用一层MLP来做aggregate,之所以使用MLP,是因为Universal Approximation Theorem:参数量达到一定程度的单层MLP可以拟合任意的函数。因此,GIN中的aggregate function就变成了如下形式:
M
L
P
Φ
(
∑
x
∈
S
M
L
P
f
(
x
)
)
MLP_{\Phi}(\sum_{x \in S}MLP_{f}(x))
MLPΦ(x∈S∑MLPf(x))
同时,GIN利用了WL graph kernel的方法来实现message passing,并对结点embedding进行update,公式为:
h
u
(
k
+
1
)
=
G
I
N
C
o
n
v
(
h
u
(
k
)
,
{
h
v
(
k
)
,
v
∈
N
(
u
)
}
)
=
M
L
P
Φ
(
(
1
+
ϵ
)
⋅
M
L
P
f
(
h
u
(
k
)
)
+
∑
v
∈
N
(
u
)
M
L
P
f
(
h
v
(
k
)
)
)
h^{(k+1)}_{u}\ =\ GINConv(h^{(k)}_u, \{h^{(k)}_{v},v \in N(u)\})\ \\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =\ MLP_{\Phi}((1+\epsilon) \cdot MLP_{f}(h_{u}^{(k)})+\sum_{v \in N(u)}MLP_{f}(h^{(k)}_{v}))
hu(k+1) = GINConv(hu(k),{hv(k),v∈N(u)}) = MLPΦ((1+ϵ)⋅MLPf(hu(k))+v∈N(u)∑MLPf(hv(k)))
这里的
k
k
k不再代表层数,而是代表WL test的迭代次数。