Representing Long-Range Context for Graph Neural Networks with Global Attention

abstract

图神经网络是结构化数据集的强大架构。然而,目前的方法很难代表长距离的依赖关系。缩放GNN的深度或宽度不足以扩大感受野,因为更大的GNN会遇到优化不稳定的问题,如梯度消失和再现过平滑,而基于集合的方法还没有变得像计算机视觉那样普遍有用。在这项工作中,我们提出使用基于Transformer的自我关注来学习长距离的成对关系,用一种新颖的 "读出 "机制来获得全局图的嵌入。最近的计算机视觉结果发现位置不变的注意力在学习长距离关系方面表现良好,受此启发,我们的方法,我们称之为GraphTrans,在标准的GNN模块之后应用了一个变异的Transformer模块。这种简单的结构导致了在几个图分类任务上的最先进的结果,超过了明确编码图结构的方法。我们的结果表明,没有图结构的纯学习方法可能适合学习图上的高层次、长距离关系。

introduction

图形神经网络(GNNs)使深度网络能够处理结构化的输入,如分子或社会网络。图神经网络学习映射,从其邻域的结构和特征中计算图节点和/或边缘的表示。这种邻域-局部聚合利用了图的连接性所编码的关系归纳偏向[3]。与卷积神经网络(CNN)类似,GNN可以通过堆叠层聚合来自本地邻域以外的信息,有效地扩大GNN的接受域。

然而,当GNN的深度增加时,其性能急剧下降[21]。这种限制损害了GNN在全图分类和回归任务中的性能,在这些任务中,我们想要预测描述整个图的目标值,而这些目标值可能依赖于长距离的依赖关系,而这些依赖关系可能不会被具有有限感受野的GNN所捕获[35]。例如,考虑到一个大型图,其中节点A必须关注K跳之外的遥远节点B。如果我们的GNN层只在一个节点的一跳邻域上聚集,那么就需要一个K层的GNN。然而,这个GNN的接收场的宽度将呈指数增长,稀释了来自节点B的信号。也就是说,简单地将接收场扩大到K跳邻域可能也不能捕捉到这些长距离的依赖性[40]。通常,"太深 "的GNN会导致节点表征在整个图上坍缩成等值,这种现象有时被称为过度平滑或过度量化[21, 5, 2]。因此,常见的GNN架构的最大上下文大小实际上是有限的。

在这项工作中,我们在GNN的图池和学习长距离依赖方面采取了不同的方法。与分层汇集一样,我们的方法也受到了计算机视觉方法的启发:我们将一些明确编码相关关系归纳偏见的原子操作(即CNN中的卷积或空间汇集,GNN中的邻域粗化)替换为纯粹的学习操作,如注意力[11, 4, 7]。

我们的方法,我们称之为Graph Transformer(GraphTrans,见图1),在标准的GNN层堆栈之上添加一个Transformer子网络。这个Transformer子网络以一种位置无关的方式明确地计算所有成对节点的相互作用。这种方法是直观的,因为它保留了GNN作为一个专门的架构来学习节点近邻结构的局部表示,同时利用Transformer作为一个强大的全局推理模块。这与最近的计算机视觉架构相似,作者发现硬关系归纳偏见对学习短距离模式很重要,但在建模长距离依赖关系时却不那么有用,甚至起反作用[25]。由于没有位置编码的Transformer是不变的,我们发现它很自然地适用于图。此外,GraphTrans不需要任何专门的模块或架构,可以在任何现有的GNN骨干上的任何框架中实现。

我们在各种流行的图分类数据集上评估GraphTrans。我们发现在OpenGraphBenchmark[15]上的准确率有了明显的提高,我们在两个图分类任务上取得了最先进的结果。此外,我们发现在分子数据集NCI1上也有很大的改进。令人惊讶的是,我们发现我们的简单模型优于复杂的基线,通过分层聚类,如自我注意池[20],在图中进行长程建模。

related work

Graph Classification.

图分类是现实世界应用中的一项重要任务。尽管GNN将结构化数据编码为节点表征,但将表征聚合为一个单一的图嵌入用于图分类仍然是一个问题。与CNN类似,GNN中的集合可以是全局的,将一组节点和/或边的编码减少到一个单一的图形编码,或者是局部的,折叠节点和/或边的子集以创建一个更粗糙的图形。与CNN中的中间池的使用相类似,一些作者提出了局部池操作,旨在用于GNN层堆栈中,逐步粗化图形。提出的方法包括学习的池化方案[37, 20, 14, 16, 1等]和基于经典图粗化方案的非学习的池化方法[10, 9等]。然而,GNN中分层的、基于粗化的池化的有效性或必要性还不清楚[23]。另一方面,最常见的全局、全图池化方法是:i)非学习平均或最大节点池化;ii)"虚拟节点 "方法,即GNN的最后一层输出一个与图中每个 "真实 "节点相连的单一虚拟节点的嵌入。

与图池有关的一个值得注意的工作是Thost和Chen的DAGNN(有向无环图神经网络)[27],它在OGBG-Code2上获得了以前最先进的准确性。DAGNN层通过一个遍历DAG的RNN在每层内聚集整个图,而不像大多数GNN层只聚集在一个节点的邻域。虽然他们没有将这种方法定性为池化操作,但它与GraphTrans相似,因为它作为一个学习的全局池化(在它将DAG中每个节点的嵌入聚合到汇节点中),可以模拟长距离的依赖关系。请注意,GraphTrans也是对DAGNN的补充,因为他们最终的图级池化操作是对汇节点的全局最大池化,而不是学习操作。

Transformers on Graphs

一些作者已经研究了Transformer架构在图中的应用。最近的作品如Zhang等人[38]、Rong等人[24]以及Dwivedi和Bresson[12]提出了GNN层,让节点通过Transformer式的注意来关注周围一些邻居中的其他节点,而我们将自我注意用于包络不变的、图层面的汇集或 "读出 "操作,将节点编码折叠成单一的图编码。其中,Zhang等人[38]和Rong等人[24]通过允许节点关注的不仅仅是一跳邻域,解决了学习长距离依赖关系而不过度平滑的问题。Zhang等人[38]将关注的邻域半径作为一个调整参数,Rong等人[24]在训练和推理期间关注随机大小的邻域。相比之下,我们使用全图的自我关注来允许学习长距离的依赖关系。

虽然Zhang等人[38]没有考虑全图预测问题,但在Dwivedi和Bresson[12]的案例中,当需要对图分类或回归进行全图嵌入时,他们在节点上使用全局平均池,而Rong等人[24]在节点上采取加权和,计算的权重绕过hLv 's到两层MLP。还需要注意的是,之前的工作考虑了Transformer的位置编码的特定图形版本,而我们省略了位置编码以确保包络不变性。

Efficient Transformers.

变换器[28]已被广泛用于序列建模。最近,出现了对变换器结构的修改以进一步提高效率[34, 19, 6]。LiteTransformer[34]的FLOPs较少,Reformer[19]的复杂度较高,Performer[6]的计算和内存复杂度都较低。神经架构搜索(NAS)也被应用于Transformer,以满足边缘设备的资源限制[32]。这些现成的架构与我们的GraphTrans是正交的,可以采用这些架构来提高扩展性。

Motivation: Modeling Long-Range Pairwise Interactions

总而言之,试图通过堆叠GNN层或分层池在图上进行长程学习还没有带来性能的提高,虽然有些工作在将单个GNN层的接受域扩展到一跳邻域之外方面取得了一些成功[38, 24, 40],但这种方法如何扩展到有数千个节点的非常大的图上还有待观察。

在最近的计算机视觉文献中可以找到另一种方法的灵感。在过去的几年中,研究人员发现注意力机制可以作为传统CNN卷积的替代品[4, 7]:注意力层可以学习复制由局部卷积引起的强关系归纳偏见。最近,一些计算机视觉任务的最先进的方法在传统的CNN主干上使用了一个注意力式的子模块[2, 33, 等]。这些结果表明,虽然强关系归纳偏见有助于学习局部的、短距离的关联,但对于长距离的关联,结构较少的模块可能是首选[2]。

我们用GraphTrans模型将这一见解运用到图学习领域,该模型使用传统的GNN子网络作为骨干,但将学习长距离的依赖关系留给没有图空间预设的Transformer子网络。如前所述,我们的Transformer应用允许每个节点关注其他每个节点(与其他只允许关注邻域的Transformer应用方法不同),这激励Transformer学习最重要的节点-节点关系,而不是偏爱附近的节点(后者的任务已经被卸载到前面的GNN模块)。

从质量上看,这个方案提供了证据,证明长距离的关系确实很重要。图2描述了GraphTrans在OGB Code2数据集上的一个应用实例。在这个任务中,我们接收了通过解析Python方法得到的抽象句子树,并需要预测构成方法名称的标记。注意力图表现出与变形金刚的NLP应用中类似的模式:一些节点从许多其他节点获得了显著的权重,而不管它们之间的距离如何。请注意,尽管这两个节点相距五跳,但17号节点对8号节点给予了显著的重要性。另外,在图2的注意力地图中,索引18指的是与我们用作读出机制的特殊标记相对应的嵌入,下面将详细介绍。我们允许这个嵌入是可学习的,所以许多关注它的节点(由第18列中的许多深色单元表示)可能表明这些节点正在从学习的嵌入中获得一些图形的一般记忆。这种定性的可视化,加上我们新的最先进的结果,表明在学习长距离的依赖关系时,去除空间预设可能是有效的图总结所必需的。

Learning Global Information with GraphTrans

参照图1,GraphTrans由两个主要模块组成:一个GNN子网络然后是一个Transformer子网络。接下来我们将详细讨论这些模块。
请添加图片描述
请添加图片描述
请添加图片描述

embedding as a GNN “readout” method.

如前所述,对于全图分类,我们需要一个描述全图的单一嵌入向量。在GNN文献中,这种将每个节点和/或边缘的嵌入折叠成单一嵌入的模块被称为 "读出 "模块,最常见的读出模块是简单的均值或最大值集合,或者是与网络中每个其他节点相连的单一 “虚拟节点”

在这项工作中,我们提出了一个特殊标记的读出模块,类似于Transformers.的其他应用中使用的模块。在使用Transformers.的文本分类任务中,一个常见的做法是在输入序列中附加一个特殊的标记,然后把这个标记的位置对应的输出嵌入作为整个句子的表示。这样,转化器将被训练为将句子的信息汇总到该嵌入中,通过计算标记和句子中的其他标记之间的一对一关系,关注模块`

请添加图片描述
这种特殊令牌的读出机制可以被看作是对虚拟节点读出的概括或 "深度 "版本。虚拟节点读出的概括或 "深度 "版本。虚拟节点方法要求图中的每个节点将其信息发送给虚拟节点,并且不允许学习图中节点之间的成对关系,除非是在虚拟节点的嵌入中(可能会产生信息瓶颈),而Transformer式的特殊令牌读出方法让网络在早期层中学习长距离的节点与节点之间的关系,然后才需要在后期层中进行提炼。

Experiments

我们在生物学、计算机编程和化学这三种模式的图分类任务上评估GraphTrans。我们的GraphTrans在所有这些基准测试中都取得了一致的改进,表明了该框架的通用性和有效性。我们所有的模型都是用Adam优化器[17]训练的,学习率为0.0001,权重衰减为0.0001,并使用默认的Adam β参数。我们实验中使用的所有Transformer模块的嵌入维度dTF为128,前馈子网络的隐藏维度为512。下面描述的Transformer基线只用节点嵌入序列来训练,放弃了图结构

Biological benchmarks

我们选择了两个常用的图分类基准,NCI1和NCI109[31]。它们中的每一个都包含约4000个图,平均约有30个节点,代表生化化合物。我们的任务是预测一个化合物是否含有抗肺癌活性。我们按照文献[20,2]中对NCI1和NCI109的设置,将数据集按8:1:1的比例随机分成训练集、验证集和测试集。

Training Setups.

实验中所有的GNN模块都遵循OGB中提供的默认GIN模型的设置,有4层和300个隐藏维度。我们对所有的模型进行了100次历时训练,批次大小为256,并以最佳验证ROC-AUC报告测试结果。对于GNN和Transformer模块,我们采用了0.3的drouout。我们使用GIN作为基线和GNN模块,因为它在Molpcba数据集上的表现比GCN模型更好。
请添加图片描述我们在表1中报告了NCI1和NCI109的结果。简单基线,包括GCN Set2Set、SortPool和SAGPool,取自[20],而强基线[13]以及FA层[2]。在表1中,我们的Graph Transformer(小)的架构与简单基线相同,但对NCI1和NCI109的平均准确率分别提高了7.1%和5.1%。我们还测试了用GIN作为编码器的框架(GraphTrans (large)),以与强基线中的设置保持一致,这也明显提高了强基线的准确性,对NCI1提高了1.1%,对NCI109提高了8.2%,即使没有深度GNN,使用4层而不是8层。

Chemical benchmarks

对于化学基准,我们在一个比NCI数据集更大的数据集上评估我们的GraphTrans,即来自Open Graph Benchmark(OGB)的molpcba[15]。它包含437929个图,平均有28个节点。数据集中的每个图代表一个分子,其中节点和边分别是原子和化学键。任务是预测一个分子的多种属性。我们使用基准的标准划分。GIN和GIN-Virtual基线的性能与OGB排行榜[15]上的报告一致.

Training Setups.

实验中所有的GNN模块都遵循OGB中提供的默认GIN模型的设置,有4层和300个隐藏维度。我们对所有的模型进行了100次历时训练,批次大小为256,并以最佳验证ROC-AUC报告测试结果。对于GNN和Transformer模块,我们采用了0.3的droupout。我们使用GIN作为基线和GNN模块,因为它在Molpcba数据集上的表现比GCN模型更好。

请添加图片描述
在表2中,我们报告了Molpcba的验证和测试集上的ROC-AUC。尽管Transformer在这个数据集上的效果很差,但我们的GraphTrans仍然改善了GIN和GIN-Virtual基线的ROC-AUC。这表明我们的设计可以从GNN学习的局部图结构和Transformer模块基于GNN嵌入检索的远距离概念中获益。

Computer programming benchmark

对于计算机编程基准,我们还采用了一个大型数据集,即来自OGB的code2,它有45741个图,每个图平均有125个节点。该数据集是一个抽象语法树(AST)的集合,来自大约450k Python方法定义。我们的任务是预测形成方法名称的子标记,给定由AST代表的方法主体。我们还采用了基准的标准数据集分割。所有的基线性能都是在OGB排行榜上报告的。

Training Setups.

对于计算机编程基准,我们还采用了一个大型数据集,即来自OGB的code2,它有45741个图,每个图平均有125个节点。该数据集是一个抽象语法树(AST)的集合,来自大约450k Python方法定义。我们的任务是预测形成方法名称的子标记,给定由AST代表的方法主体。我们还采用了基准的标准数据集分割。所有的基线性能都是在OGB排行榜上报告的。

在表3中,我们将我们的GraphTrans与Code2数据集排行榜上的顶级架构进行比较。随着每个图中平均节点数的增加,全局信息变得更加重要,因为GNN从远处的节点收集信息变得更加困难。即使不进行大量的调整,GraphTrans在排行榜上也明显优于最先进的(DAGNN)[27]。我们还包括PNA模型和我们的GraphTrans的结果,PNA模型作为GNN编码器。我们的GraphTrans也明显改善了结果,这表明我们的架构与GNN编码器模块的变体是正交的。请添加图片描述

Transformers can capture long-range relationships

正如我们之前在图2中观察到的和在第3节中讨论的, transformer模块内部的注意力可以捕捉到GNN模块难以学习到的长距离信息。
为了进一步验证这一假设,我们设计了一个实验来证明转化器模块可以学习到GNN模块的额外信息。在表4中,我们首先预训练一个GNN(GCN-Virtual),直到在Code2数据集上收敛,然后冻结GNN模型并在它之后插入我们的Transformer模块。通过在训练集上用固定的GNN模块训练模型,我们仍然可以观察到在验证集上有0.0022个F1分数的改进,在测试集上有0.0042个。这表明Transformer可以学习到GNN模块难以学到的额外信息。
通过预训练和解冻的GNN模块,我们的GraphTrans可以获得更高的F1分数。这可能是因为GNN模块现在可以专注于学习局部结构信息,而将长程信息的学习留给后面的Transformer层。该模型得益于[34]中提到的专门化。请注意,在表4的所有实验中,为了简单起见,我们没有将输入图的嵌入连接到Transformer的输入。
请添加图片描述

Effectiveness of embedding

在图2b中,我们可以观察到第18行(最后一行为)在多列上有暗红色,这表明学会了关注图中的重要节点,以学习整个图的表示。请添加图片描述
我们还定量地检查了我们的嵌入的有效性。在表5中,我们测试了几种常用的方法来进行序列分类。平均操作将transformer的输出嵌入平均到一个图形嵌入;最后操作将输出序列中的最后一个嵌入作为图形嵌入。定量结果表明,嵌入是最有效的,在测试集上有0.0275的改进,因为该模型可以学习从不同的节点检索信息,并将它们聚合成一个嵌入。输入图中的嵌入和转化器的输入嵌入的串联可以进一步提高验证和测试的F1分数,达到0.1670和0.1733。

Scalability

为了对GraphTrans在100个节点以上的大型图上的扩展性进行定量基准测试,我们在不同的图大小和边缘度下进行了训练的迭代时间的微观基准。我们在随机生成的具有不同节点数和边缘密度的Erdos-Renyi图上训练基线。如表6所示,当节点数和边缘密度增加时,我们的GraphTrans模型的扩展性至少与GCN模型一样好。GCN和GraphTrans在大型密集图中都出现了内存错误(OOM),但我们注意到GraphTrans的内存消耗与GCN的基线相似。

请添加图片描述

Computational efficiency请添加图片描述

为了评估我们的GraphTrans在特定GNN骨干上增加的开销,我们评估了每次迭代的前向传递运行时间和后向传递运行时间。我们将模型规范化,使其具有大致相似的参数数量。结果显示在表7中。对于NCI1数据集,GraphTrans的训练速度实际上比可比的GCN模型快。对于OGB-molpcba和OGB-Code2数据集,GraphTrans比基线GNN架构慢了7-11%。

Number of parameters请添加图片描述

我们在表8中比较了GNN基线和GraphTrans在不同数据集上的参数数量。总的来说,GraphTrans只在Molpcba和NCI上略微增加了总参数。对于Code2,GraphTrans的参数效率大大高于GNN,同时将测试F1得分从0.1629提高到0.1810。参数效率提高的一个原因是Transformer在昂贵的最终预测层之前减少了特征尺寸。

Conclusion

我们提出了GraphTrans,一个简单而强大的框架,用于学习GNN的长距离关系。最近的研究结果表明,对于高层次的长距离关系来说,结构性先验因素可能是不必要的,甚至是适得其反的,利用这些结果,我们用一个后续的互变不变的Transformer模块增强了标准的GNN层堆栈。变换器模块作为一个新的GNN "读出 "模块,同时允许学习图节点之间的成对互动,并将它们总结为一个特殊的标记嵌入,就像变形器的常见NLP应用中所做的那样。这个简单的框架导致了在程序分析、分子和蛋白质关联网络等几个图分类任务中,对技术水平的惊人改进。在某些情况下,GraphTrans优于那些试图编码特定领域结构信息的方法。总的来说,GraphTrans提出了一种简单而普遍的方法来改善长距离的图分类;接下来的方向包括应用于节点和边的分类任务,以及进一步提高转化器对大型图的可扩展性。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值