CS224W Lecture6-8笔记

Graph Neural Network

lecture 6,7,8详细介绍了图表示学习中的深度学习方法。之前介绍过Node Embedding,但是都是基于一些很“shallow”的特征,GNN可以帮助我们更高效地学习到更好的node、link、graph embedding。课程中所讲到的GNN都是spatial-based,也就是模型的结构是基于结点地空间特征,具体来说就是当前结点地embedding由它的neighbor得来,而spatial-based GNN遵循的一种模式叫做Message + Aggregate。

A Single Layer

GNN中的每一层遵循的是Message + Aggregate模式,不同的message passing和aggregate方式衍生出了不同的GNN模型,如:GCN、GraphSAGE、GAT等等。

在这里插入图片描述

Message Computation

首先来看如何计算message。每个结点都有自己的message,我们设 m u ( l ) m_{u}^{(l)} mu(l)表示结点 u u u在第 l l l层的message, h u ( l ) h_{u}^{(l)} hu(l)表示结点 u u u在第 l l l层的embedding,那么message的计算公式为:
m u ( l )   =   M S G ( l ) ( h u ( l − 1 ) ) m_{u}^{(l)}\ =\ MSG^{(l)}(h_{u}^{(l-1)}) mu(l) = MSG(l)(hu(l1))
其中, M S G ( l ) MSG^{(l)} MSG(l)表示第 l l l层的message function,选择有很多,最直接的就是乘一个参数矩阵 W ( l ) W^{(l)} W(l),于是 m u ( l )   =   W ( l ) h u ( l − 1 ) m_{u}^{(l)}\ =\ W^{(l)}h_{u}^{(l-1)} mu(l) = W(l)hu(l1)。由于我们在更新每个结点的embedding时,也希望能把当前结点的message考虑进来,而不单单是考虑它neighbor的message,因此我们再定义一个参数矩阵 U ( l ) U^{(l)} U(l)用于计算当前结点的message,即 m v ( l ) = U ( l ) h v ( l − 1 ) m^{(l)}_{v}=U^{(l)}h^{(l-1)}_{v} mv(l)=U(l)hv(l1)

Aggregate

Aggregate部分是将当前结点信息和其neighbor信息进行结合,得到当前结点的embedding,写成generalized的式子就是
h u ( l )   =   A G G ( l ) ( { m v ( l ) , v ∈ N ( u ) } , m u ( l ) ) h_{u}^{(l)}\ =\ AGG^{(l)}(\{m_v^{(l)}, v \in N(u)\}, m_{u}^{(l)}) hu(l) = AGG(l)({mv(l),vN(u)},mu(l))
这里的 A G G AGG AGG就表示aggregate function。常见的aggregate function有Sum、Mean、Max等等.

Variants of GNN

根据上述的Message + Aggregate模式,我们就可以来分析一下几个经典的GNN

GCN

h u ( l )   =   σ ( W ( l ) ∑ v ∈ N ( u ) h v ( l − 1 ) ∣ N ( u ) ∣ ) h^{(l)}_{u}\ =\ \sigma(W^{(l)}\sum_{v \in N(u)} \frac{h^{(l-1)}_{v}}{|N(u)|}) hu(l) = σ(W(l)vN(u)N(u)hv(l1))

Message

从公式中可以看出,GCN的message function是
W ( l ) h v ( l − 1 ) ∣ N ( u ) ∣ W^{(l)}\frac{h_{v}^{(l-1)}}{|N(u)|} W(l)N(u)hv(l1)

Aggregate

GCN中的Aggregate function使用的是Sum

GraphSAGE

h u ( l )   =   σ ( W ( l ) ⋅ C O N C A T ( h u ( l − 1 ) , A G G ( { h v ( l − 1 ) , v ∈ N ( u ) } ) ) h_{u}^{(l)}\ =\ \sigma(W^{(l)} \cdot CONCAT(h_{u}^{(l-1)},AGG(\{h_{v}^{(l-1)}, v \in N(u)\})) hu(l) = σ(W(l)CONCAT(hu(l1),AGG({hv(l1),vN(u)}))

GraphSAGE中的Aggregate分为两个部分,一个是对于neighbor embedding W ( l ) h v ( l − 1 ) W^{(l)}h_{v}^{(l-1)} W(l)hv(l1)的Aggregate,这一步的aggregate function可以有多种选择;另一个就是当前结点message与上一步aggregate得到的结果进行aggregate,这里选用的是concatenation。回到第一步的aggregate function,比较常见的选择如下:

  • Mean:简单的求一个均值,类似于GCN中的操作
  • Pool:先对每个neighbor embedding做一个transformation,然后再用mean-pooling或者max-pooling
  • LSTM:也可以把neighbor embedding当作序列信息然后用LSTM进行aggregate

GraphSAGE中还有一个小trick:对每一层每个结点的embedding进行 L 2 L_2 L2 normalization,这种做法在一些情况下能够提高performance

GAT

GAT是在GNN中引入了注意力机制。一个很intuitive的事实是,当我们在对一个结点的neighbor message进行aggregate的时候,每个neighbor的重要程度应该是不一样的。因此GAT中用 α u v \alpha_{uv} αuv表示 v v v的message对 u u u的一个权重
h u ( l )   =   σ ( ∑ v ∈ N ( u ) α u v W ( l ) h v ( l − 1 ) ) h_{u}^{(l)}\ =\ \sigma(\sum_{v \in N(u) }\alpha_{uv}W^{(l)}h_{v}^{(l-1)}) hu(l) = σ(vN(u)αuvW(l)hv(l1))
α u v \alpha_{uv} αuv的计算方式如下:
e u v   =   f a t t e n t i o n ( U ( l ) h u ( l − 1 ) , W ( l ) h v ( l − 1 ) ) α u v   =   e x p ( e u v ) ∑ v ′ ∈ N ( u ) e x p ( e u v ′ ) e_{uv}\ =\ f_{attention}(U^{(l)}h_{u}^{(l-1)}, W^{(l)}h_{v}^{(l-1)}) \\ \alpha_{uv}\ = \ \frac{exp(e_{uv})}{\sum_{v' \in N(u)}exp(e_{uv'})} euv = fattention(U(l)hu(l1),W(l)hv(l1))αuv = vN(u)exp(euv)exp(euv)
我们还可以像transformer一样使用multi-head attention
h u ( l ) [ i ]   =   σ ( ∑ v ∈ N ( u ) α u v ( i ) W ( l ) h v ( l − 1 ) ) h u ( l )   =   A G G ( { h u ( l ) [ i ] , i = 1 , 2 , … , n } ) h_{u}^{(l)}[i]\ =\ \sigma(\sum_{v \in N(u) }\alpha^{(i)}_{uv}W^{(l)}h_{v}^{(l-1)}) \\ h_{u}^{(l)}\ =\ AGG(\{h_{u}^{(l)}[i], i = 1,2,\dots,n\}) hu(l)[i] = σ(vN(u)αuv(i)W(l)hv(l1))hu(l) = AGG({hu(l)[i],i=1,2,,n})
GAT的优势如下:

  1. 让我们能够捕获到不同neighbor的重要程度
  2. 计算高效,这一点是attention共有的,可以并行计算
  3. 存储高效
  4. 具有Inductive的能力,不依赖于全局的结构,关注的是局部信息

Stacking Layers of GNN

讲完单层的GNN结构,下一步就应该增加网络的深度了。在CV或者NLP的一些模型中,通常我们把模型加的越深越好,模型越深,表达能力大概率会越强。但是GNN有所不同,如果一味的加深网络深度,那会带来一个Over-Smoothing的问题。

Over-Smoothing

首先介绍一个概念叫做Receptive Field:决定一个结点embedding的结点集合。由于每一层我们是用每个结点的neighbor来更新embedding,因此随着网络深度变大,每个结点的receptive field也会变得越来越大

在这里插入图片描述

这张图非常直观地展示了黄色结点receptive field逐渐变大地过程。因此,当网络太深时,每个结点的receptive field会出现很大程度的overlap,这就会导致每个结点的embedding最后会趋于一致,这就是over-smoothing problem。

为了解决over-smoothing问题,lecture中提到了两种解决方案:

  • 增加单层GNN的表达能力。既然深度不能太大,那我们就可以在每一层GNN上尽可能的提升performance
  • 添加残差结构,在层与层之间加skip connection

Expressivity of GNN

GNN模型的表达能力,简单来讲,就是模型能够区分不同结构的能力。首先介绍了一个概念叫做:computation graph

Computation Graph

每个结点的computation graph是由它的local neighborhood决定的,模型在做Aggregate的时候就是基于每个结点的计算图。以下图为例

在这里插入图片描述

这分别是5个结点的computation graph。我们可以看出,不同的local neighbor structure会带来不同的computation graph,而如果我们的模型对于不同的computation graph能够生成不同的embedding,我们就能区分不同的结点。

Graph Isomorphic Network(GIN)

根据上面的描述,我们发现GNN的表达能力的一大关键就是aggregate function的选择,我们最希望的就是对于不同的computation graph,我们aggregate出来的结果也是不同的,换言之,我们希望aggregate function是injective的。

先前我们所看到的各种aggregate function(sum、max、mean)其实都有一些failure cases,做不到injective。

在这里插入图片描述

因此GIN模型提出使用一层MLP来做aggregate,之所以使用MLP,是因为Universal Approximation Theorem:参数量达到一定程度的单层MLP可以拟合任意的函数。因此,GIN中的aggregate function就变成了如下形式:
M L P Φ ( ∑ x ∈ S M L P f ( x ) ) MLP_{\Phi}(\sum_{x \in S}MLP_{f}(x)) MLPΦ(xSMLPf(x))
同时,GIN利用了WL graph kernel的方法来实现message passing,并对结点embedding进行update,公式为:
h u ( k + 1 )   =   G I N C o n v ( h u ( k ) , { h v ( k ) , v ∈ N ( u ) } )                                            =   M L P Φ ( ( 1 + ϵ ) ⋅ M L P f ( h u ( k ) ) + ∑ v ∈ N ( u ) M L P f ( h v ( k ) ) ) h^{(k+1)}_{u}\ =\ GINConv(h^{(k)}_u, \{h^{(k)}_{v},v \in N(u)\})\ \\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ =\ MLP_{\Phi}((1+\epsilon) \cdot MLP_{f}(h_{u}^{(k)})+\sum_{v \in N(u)}MLP_{f}(h^{(k)}_{v})) hu(k+1) = GINConv(hu(k),{hv(k),vN(u)})                                         = MLPΦ((1+ϵ)MLPf(hu(k))+vN(u)MLPf(hv(k)))
这里的 k k k不再代表层数,而是代表WL test的迭代次数。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值