雪花假设:训练deep GNN 新思路

本文由中科大数据智能实验室联合阿卜杜拉国王科技大学(KSUST)、同济大学、新加坡国立大学、深圳科技大学、香港科技大学等多家院校联合提出了一种全新的训练深度GNN的策略—雪花假设,旨在帮助未来训练深度GNN并克服其中可能出现的over-smoothing/over-fitting等问题。

d596d6fe5b11c74e4b36e416778c6b63.png

论文标题:The Snowflake Hypothesis: Training Deep GNN with One Node One Receptive field

论文地址:https://arxiv.org/abs/2308.10051

背景

过拟合,过平滑和梯度消失是GNN领域的三个长期存在的问题,特别是当GNN仿照卷积神经网络加深网络时。因此,当在小型图上训练过参数化的GNN或使用深度GNN进行图建模时,我们通常最终会得到塌缩的权重或不可区分的节点表示。因此,训练2-4层的GNN在图领域一直是一个比较普遍的现象,大多数最先进的GNN一般也不会超过不超过4层。然而,仔细研究许多计算机视觉任务和自然语言处理上的成就,可以很好的归功于深度网络的持续有效训练。因此,图表示学习迫切需要利用更深层的图神经网络,特别是在处理以密集连接为特征的大规模图时。

最近,一些工作表明了随着深度的增加训练GNN的可行性。我们可以将现有方法总结为两类。第一类涉及继承CV领域的技术,如Res/Skip-connection,这些方法已被证明是普遍适用和实用的。例如,JKNet采用跳接方式对各层的输出进行融合,以保持不同节点之间的差异。GCNII和ResGCN采用残差连接来携带来自前一层的信息,以避免上述问题。另一类是将各种深度聚合策略与浅层神经网络相结合。例如,GDC将个性化PageRank推广为图扩散过程。DropEdge借助随机的边丢弃策略来隐式地增加图的多样性并减少消息传递。

然而,尽管CNN的残差/跳跃连接等继承机制可以部分缓解过平滑问题,但这些改进未能有效探索聚合策略与网络深度之间的关系。将残差合并到具有次优输出的层中可能无意中将有害信息传播到后续的聚合层。在第二类中,大多数现有的深度聚合策略试图对中心节点周围的邻近节点进行抽样,以隐式地增强数据多样性并防止过度平滑。不幸的是,繁琐和特殊的设计使GNN模型既不简单也不实用,缺乏在其他训练策略和特定数据集上扩展的能力。

思路:

基于上述观测,本文首次提出训练GNN时让每个节点都具备自身的感受野,通过node receptive field比喻成雪花,来反应每个节点独一无二的特性(https://zhuanlan.zhihu.com/p/100948902),基于大量实验,本文提出了雪花假设

4681da2004d95a701b5137bdd5150b74.png

为了更好的发现独一无二的“雪花”,本文提出了两种策略:SnoHv1通过判断邻接矩阵的梯度,将每一行的梯度求和,并找出梯度最小的行,来对邻接矩阵进行layer-wise element pruning(注意这儿不删除自环中的对角线元素),layer-wise element pruning可以很好的保证某些节点在聚合深度上实现“early stopping”,使得某些节点只对外输出信息,而对内聚合的通道消失。算法思路如下:

155822f148066cbf230fb07363d5ffcf.png

然而,SnoHv1在大图上的拓展能力稍微逊色,因为要判断百万元素的梯度,并进行求和,这会导致训练速度十分低效。为此,本文继续给出了效率更高的SnoHv2:

cb76d5227fb09ab7cc46561113f161e4.png

通过判断深层和初始层的余弦距离,SnoHv2的想法十分简单,当深度加深时,过平滑问题出现后,节点表示会趋同,相比于第一层的余弦距离,后层的余弦距离会不断减小,当小于第一层的某个百分比之后,我们就对该节点进行layer-wise element pruning。进而更好的帮助每个节点实现深度上的“early stopping”。

实验:

本文进行了大量的实验,(1)涵盖不同的训练策略,如迭代剪枝、充分预训练、重初始化。(2)与目前的深度GNN架构进行结合,如ResGCN,ResGCN+,JKNet,PairNorm等;(3)切换浅层GNN架构,GIN,GAT等;(4)与目前主流的聚合策略相比,如DropEdge。

6a493bf2b15b86e1d37a01242bbe4ea2.png dae6941f70a515c0cce929e1d8cbb9e1.png bba93e19e7ea4dea01ec55b19c0c50c0.png

由于我们的算法是基于剪枝实现的,我们还比较了目前主流的剪枝框架:

34eece09b0101a1686f04394c8df1295.png

我们的进行了六个数据集的实验,包含小数据集以及千万边级别的大数据集,我们的结果表明,在深度GNN中确实有很多节点进行深度上的早停,不会影响模型性能,甚至可以出现较为明显的提升,证明了我们框架的有效性。

结语:

本文提出了一种全新的假设,认为每个节点在深度场景下都应该具备自己独立的感受野,具有内在的可解释性,在继承剪枝算法优势(加速推理时间和减少存储开销)的同时,也可以使得当前的图剪枝算法从中收益。更重要的是,该算法简单方便。与复杂聚合策略的设计相比,该框架没有引入任何额外的信息(如可学习的参数),可以很容易地扩展到深度GNN。我们进行了全面的实验,跨越了一系列训练算法,与各种骨干架构的集成,并在多个图基准上与DropEdge/UGS框架进行了比较。研究结果表明,SnoHv1/v2始终能提供出色的性能,即使在邻接矩阵明显稀疏的情况下。这些结果强调了我们最初的假设:某些节点需要在其深度进展中提前终止。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值