论文学习笔记:GraphSAGE


本文介绍一个inductive node embedding的框架——GraphSAGE。算法与其他transductive learning方法的优势在于,学习一系列能够聚合邻居节点特征,生成中心节点表示的aggregator。而不是只学习当前训练集的node embedding。所以模型可以很好的推广到新的graph上。算法在citation和Reddit数据集上取得了art of state的表现。并且在模型在完全没见过的图数据上表现出不错的泛化能力。

一、GraphSAGE

回顾原始 GCN,每层图卷积本质上是对中心节点的所有1阶邻居的特征做message passing,这些邻居节点的特征传递到中心节点之后,与中心节点的特征一起累计加权平均,再经过非线性变换,得到中心节点的表示。

GraphSAGE方法在此基础上做了三个方面的拓展。第一是引入采样技术,通过对中心节点的邻居均匀采样,得到一个包含中心节点和它的邻居节点的子图,对这个子图做卷积。采样技术的引入,一定程度上缓解了计算和内存压力,使模型可以在大规模图数据上训练;第二是抽象出一个 aggregate architectures,在卷积阶段聚合传递过来的邻居节点的特征。原始GCN中对passing message累计加权平均是aggregate architecture的一种形式。实际中可以根据需要选择不同的aggregate function。第三是卷积时concat中心节点本身的特征和aggregate的到的邻居节点表达,达到了在神经网络不同层之间‘skip-connection’的效果。

Screenshot from 2018-09-20 16:18:16.png-169.3kB

1.GraphSAGE forward propagation

对于给定的 G = ( V , ε ) \mathcal{G}=(\mathcal{V}, \varepsilon) G=(V,ε),graph中顶点的特征 { x v , ∀ v ∈ V } \{\mathbf{x}_v, \forall v \in \mathcal{V}\} {xv,vV}是模型的输入。假设在算法向前传播迭代次数为K,也就是神经网络的深度K,初始化每层的权重参数矩阵 W k , ∀ k ∈ { 1 , . . . , K } W^k, \forall k \in \{1,...,K\} Wk,k{1,...,K} 和aggregate function A G G R E G A T E k , ∀ k ∈ { 1 , . . . , K } AGGREGATE_k, \forall k \in \{1,...,K\} AGGREGATEk,k{1,...,K}中的参数。

image_1cnqtpkq5nfe1f04gft1f3f1ai7p.png-124.6kB

算法在每次迭代过程中,中心节点增量聚合邻居节点的信息。随着迭代的进行,获得越来越多距离更远处节点的信息。
在第k次迭代过程中,对图中的任意节点v,做了如下几步操作:

  • 1)根据邻接矩阵,找到中心节点v的邻居节点,均匀采样得到集合 N ( v ) \mathcal{N}(v) N(v)。GCN使用所有邻居节点,并不采样。这里的采样不仅可以降低计算复杂,还提高泛化能力。
  • 2)使用聚合函数 A G G R E A T E k AGGREATE_k AGGREATEk,聚合集合 N ( v ) \mathcal{N}(v) N(v)中节点的特征,得到 h N ( v ) k \mathbf{h}_{\mathcal{N}(v)}^k hN(v)k。聚合函数有多种形式,GCN中所用的方法是后面提到的均值聚合形式的变体。
  • 3) concat h N ( v ) k \mathbf{h}_{\mathcal{N}(v)}^k hN(v)k和节点v本身的特征,再经过非线性变换,得到 h v k \mathbf{h}_v^k hvk
  • 4)归一化得到本次迭代节点v的特征。

与 Weisfeiler-Lehman同构检测的关系 我们对算法做一些调整,将迭代次数设为 K = ∣ V ∣ K=|\mathcal{V}| K=V、假设权重矩阵独立、使用合适的hash函数作为aggregator, GraphSAGE 方法是 WL test的一个特例。如果GraphSAGE 在两个子图上的输出表示 { z v ,   ∀ v ∈ V } \{\mathbf{z}_v,\ \forall v \in \mathcal{V}\} {zv, vV} 相等,那么我们可以说这两个子图同构。当然这里GraphSAGE的目的是输出节点的表示,而不是同构测试。只是说明WL同构测试为算法提供了理论基础。

邻居节点的定义 算法中不使用不使用中心节点的所有1阶邻居。而是对这些邻居节点均匀采样一个固定大小的样本集合 N ( v ) \mathcal{N}(v) N(v)。注意每次迭代中样本量的大小也不固定。这样每个batch trainging的内存和时间复杂度从 O ( ∣ V ∣ ) \mathcal{O}(|\mathcal{V}|) O(V)下降到 O ( ∏ i = 1 K S i ) \mathcal{O}(\prod_{i=1}^K S_i) O(i=1KSi),其中   i ∈ { 1 , . . . , K } \ i \in \{1,...,K\}  i{1,...,K}。这样GraphSAGE在大规模图数据上也能很快的训练。

2. Learning the parameters of GraphSAGE

GraphSAGE 输出的graph节点的特征可以feed into各种下游学习任务。因此损失函数的设计与下游学习任务相关联。这里作者给出无监督学习graph node embedding的损失。损失函数倾向于连接紧密的节点的表达更相似,而连接分散的节点的表达差异很大。
J G ( z u ) = − log ⁡ ( σ ( z u T z v ) ) − Q ⋅ E v n ∼ P n ( v ) log ⁡ ( σ ( − z u T z v ) ) J_{\mathcal{G}}(\mathbf{z}_u)=-\log (\sigma (\mathbf{z}_u^T \mathbf{z}_v))-Q \cdot \mathbb{E}_{v_n\sim P_n(v)}\log (\sigma (-\mathbf{z}_u^T \mathbf{z}_v)) JG(zu)=log(σ(zuTzv))QEvnPn(v)log(σ(zuTzv))
其中,v是节点u的K阶邻域内的节点; σ \sigma σ是sigmoid函数; z u \mathcal{z}_u zu是节点u的表达; P n P_n Pn是负采样分布;Q是负采样的样本量。负采样的目的是,采集与节点u连接分散的节点。连接分散是指与节点u的距离大于K步。 − log ⁡ ( σ ( z u T z v ) ) -\log (\sigma (\mathbf{z}_u^T \mathbf{z}_v)) log(σ(zuTzv))表示最小化连接紧密的节点的表示之间的差异; − Q ⋅ E v n   P n ( v ) log ⁡ ( σ ( − z u T z v ) ) -Q \cdot \mathbb{E}_{v_n~P_n(v)}\log (\sigma (-\mathbf{z}_u^T \mathbf{z}_v)) QEvn Pn(v)log(σ(zuTzv))表示最小化连接分散的节点的表达之间的相似度。

3.Aggregate Architecture

均匀采样或者定长随机游走筛选出的节点v的邻居节点集合是无序的,aggregator function 操作的就是这些无序的特征向量,所以聚合函数除了可导、表达能力强之外,最好是对称的。本文测试了三个候选的聚合函数:

mean aggregator:对 { h u ( k − 1 ) , u ∈ N ( v ) } \{h_u^{(k-1)}, u \in N(v)\} {hu(k1),uN(v)}按元素取均值。这近似等价于GCN的向前传播法则。事实上,GCN可以算是均值聚合的归纳变体,只需将算法伪代码中第4,5行替换成如下式子
h v k ⟵ σ ( W ⋅ M E A N ( { h v k − 1 } ∪ { h u k − 1 , u ∈ N ( v ) } ) ) h_v^k \longleftarrow \sigma(W \cdot MEAN(\{h^{k-1}_v\} \cup \{h_u^{k-1}, u \in N(v)\})) hvkσ(WMEAN({hvk1}{huk1,uN(v)}))
它与mean aggregator的最大区别是,mean aggregator将节点v本身的特征和邻居特征结合的操作不是相加,而是concatenate。concatenation操作达到了GrapSAGE网络不同层之间‘skip-connection’的效果,对提高模型表达有非常重要的作用。

LSTM aggregator:基于LSTM结构的聚合函数有更强大的表达能力,但是不对称。因为它是以顺序的方式处理输入。但是这里作者还是运用LSTM来处理无序的邻居节点集合。

Pooling aggregator:pooling聚合函数是对称,可导。pooling中,每个邻居的向量独立的喂给一个全连接层。按元素max-pooling操作。
A G G R E G A T E k p o o l = max ⁡ ( { σ ( W p o o l h u i k + b ) , ∀ u i ∈ N ( v ) } ) AGGREGATE^{pool}_k=\max(\{\sigma(W_{pool}h^k_{u_i}+b), \forall u_i \in \mathcal{N}(v)\}) AGGREGATEkpool=max({σ(Wpoolhuik+b),uiN(v)})
max 表示按元素取最大值操作, σ \sigma σ是非线性激活函数。通常max pooling之前会接一个任意深度的多层感知机。多层感知机可以想象成一系列计算邻居节点特征的函数。在这些邻居节点的特征上运用max pooling能够有效捕获邻居节点集合的不同方面。

二、实验

作者在Web of Science citation dateset上paper分类、Reddit社区分类和蛋白质功能分类,三个任务上测试使用不同aggregatefunction 时GraphSAGE的表现,并设置4组对照实验。

实验显示,从算法表现和收敛速度上来看,GrapSAGE都远远优于之前的方法。并且GraphSAGE选择不同的聚合函数形式表现也会有些微的差异。使用不同的neighborhood sample size训练,得到收敛时间和算法表现,得到最佳的neighborhood sample size,当K=2时, S 1 = S 2 = 25 S_1=S_2=25 S1=S2=25。具体实验结果如下:

image_1cr50l1at1oi916ob9b91su51rat9.png-241kB

参考资料

https://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值