【论文解读】ICLR 2021丨当梯度提升遇到图神经网络,“鱼和熊掌”皆可兼得

论文:https://arxiv.org/pdf/2101.08543.pdf

代码:https://github.com/nd7141/bgnn

无论是分子设计、计算机视觉,还是组合优化和推荐系统等,图神经网络( GNNs,Graph Neural Networks )都在学习图结构化数据方面取得了巨大的成功。

这种成功的主要驱动力之一在于 GNN 能够有效地将原始输入数据编码成表达性表示,以便在新的数据集和任务上实现高质量的结果。

近期,关于 GNN 的研究主要集中在具有稀疏数据的 GNNs 上。这些数据代表同构节点嵌入(例如,独热编码图统计)或词袋表示(bag-of-words representations)。然而,图(graph)中的表数据(Tabular Data)节点之间,包含有详细的信息和大量的语义表示。

以社交网络为例,每个人都有社会人口统计学特征(如年龄、性别和毕业日期等)。显然,这些特征在数据类型、规模和缺失值上存在很大差异。对于带有表格数据的图,GNNs 仍未迈出进一步探索的脚步。另一方面,梯度提升决策树(GBDT, Gradient Boosted Decision Trees)在具有此类异构数据的应用程序中占据着主导地位。

而在 ICLR 2021 文章 Boost Then Convolve: Gradient Boosting Meets Graph Neural Networks 中,作者 Sergei Ivanov 、 Liudmila Prokhorenkova 提出了一种新颖的架构,可以联合训练 GBDT 和 GNN 以获得两者的优点:GBDT 模型处理异构特征,而 GNN 负责图结构。

各自长短

首先,简单介绍下 GBDTs 在表格数据上为何会如此成功。这主要得益于其以下特性:

(1)它可以有效地学习在表格数据中常见的具有类超平面边界的决策空间;

(2)它非常适合处理高基数的变量以及值缺失且比例不同的特征;

  (3)通过事后分析阶段为决策树或集成提供定性解释;

(4)在实际应用中,即便是在处理大量数据时也会快速收敛。

相比之下,GNNs 的关键特征则为,同时考虑节点的邻域信息和节点特征来进行预测,这区别于 GBDTs 需要额外的预处理分析来为算法提供图摘要(graph summary)。

而且理论上已经证明,消息传递 GNNs 可以在图灵机可计算的图输入上计算任何函数,即,GNNs 是已知的唯一在图上具有通用性的学习体系结构。

除此之外,与基于树的方法相比,基于梯度的神经网络学习具有更大的优势:

(1)强加于 GNNs 的关系归纳偏置(relational inductive bias)减少了对捕获网络拓扑结构的特性进行手动工程的需要。

(2)训练神经网络的端到端性质允许在依赖于应用程序的解决方案中将 GNNs 进行多分段或多组件集成。

(3)图网络的预训练表示丰富了许多重要任务的迁移学习,如无监督领域适应(UDA, unsupervised domain adaptation)、自监督学习(self-supervised learning)和主动学习机制(active learning)。

图神经网络已在各种图表示学习任务中取得了成功。面对异构表格数据时,GBDTs 通常优于其他机器学习方法。但是,对于具有表格节点特征的图,究竟该选择哪种方法?先前的 GNNs 模型主要集中在具有同构稀疏特征的网络上,而在异构环境中则表现次优。

毫无疑问,GBDTs 和 GNNs 方法都在各自的领域有着核心的竞争力。

因此,论文作者们不禁猜想:是否能充分使用这两者间的潜力?

之前曾出现尝试结合梯度增强和神经网络的方法,但鉴于计算成本高,没有将图结构数据考虑在内,以及缺乏 GNNs 体系架构中强加的关系归纳偏置等多方面因素,并未取得预期效果。而据了解,本文所介绍的任务,是首次系统地使用 GBDTs 模型进行图结构数据探索的工作。

如何结合双方优势?

在这项工作中,研究团队提出了一种新颖的体系结构 BGNN(Boost-GNN)—— 将 GBDTs 对表格节点特征的学习与 GNNs 相结合,联合训练 GBDTs 和 GNNs 以有效获得两者的最佳效果,即 GBDTs 模型处理异构特征,而 GNNs 用于解释图结构,两者结合以优化预测结果。

这使得研究团队提出的模型受益于端到端优化,允许新的树适应 GNNs 的梯度更新。通过与前沿的 GBDTs 和 GNNs 模型进行广泛的实验比较,团队充分证实了具有表格特征的各种图的性能均得到了显著提高。

具体而言,设置 G=(V, E) 为具有特征和目标标签的节点的图。在节点预测任务 (分类或回归) 中,可以借助已知的目标标签预测未知。‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍‍在整个文本中,用小写变量   (   )或   表示单个节点的特征,而   则为所有特征的矩阵表示。单个目标标签用   表示,而   和   分别代表所有的和仅为训练标签的向量。

图神经网络 (GNNs) 利用网络的连通性和节点特征来学习所有节点   的潜在表示。很多流行的 GNNs 都采用邻域聚合方法,也称为消息传递机制,通过应用其邻居表示的非线性聚合函数以更新节点   的表示。类似于传统的神经网络,GNNs 由多层组成,每个层代表一个非线性消息传递函数:

   

其中,   表示第   层上节点   的表示,   和   是聚合来自节点本地邻域表示的函数。通过最小化经验损失函数,采用梯度下降法来优化 GNN 模型的参数。

而梯度提升决策树 (GBDTs) 则是一种迭代的决策树算法,由多棵决策树组成,所有树的结论累加起来作为最终答案。此算法是在非图形表格数据上定义的广泛应用的算法,尤其适用于包含异构特性和噪声数据的任务。

梯度提升的核心思想是通过迭代添加弱模型 (往往选择决策树) 来构建强模型。形式上,在梯度提升算法的每次迭代   中,模型 f (x) 都以累加的方式进行更新:

  ε  

其中,   是前一次迭代中构建的模型,   是从某些函数族中选择的弱学习者,  ε  表示的是学习率。选择   来近似损失函数 L w.r.t 的负梯度。

弱学习者的集合通常由浅层的决策树构成。决策树是通过将特征空间递归划分成称为叶子的不相交区域来建立的。

当 GBDT 遇到 GNN

梯度提升方法在表格数据学习中效果显著;然而,将 GBDT 应用于图结构数据存在以下挑战:

(1)如何将节点特征之外的关系信号传播到本来固有的表格模型;(2)如何以端到端的方式将其与 GNNs 一起训练。

的确,GBDT 和 GNNs 的优化遵循着不同的方法:GNNs 的参数主要通过梯度下降进行优化,而 GBDT 是迭代构造的。

此时,应该采用什么解决方案?

一种简单的方法是仅在节点特征上训练 GBDT 模型,然后将获得的 GBDT 预测结果与原始输入一起作为 GNNs 的新节点特征。在这种情况下,将通过图神经网络进一步完善 GBDT 对图不敏感的预测问题。作者将这种方法称之为 Res-GNN,显然,此方法已经可以提高 GNNs 在某些任务上的性能。但是,在这种情况下,GBDT 模型完全忽略了图结构,可能会遗漏图的描述性特征,导致为 GNNs 提供的输入数据在准确性上存在偏差。

图 1

相反地,研究团队建议对 GBDT 和 GNNs 进行端到端的训练,简称为 BGNN。与先前类似的是,首先应用 GBDT,然后采用 GNNs。但考虑到最终预测的质量,团队成员对两者进行了优化。BGNN 的训练如图 1 所示。因为建立好的决策树结构离散而无法进行适当调整,因此团队成员通过添加新的树来迭代更新 GBDT 模型,使其近似于 GNNs 的损失函数。

算法 1

在算法 1 中,团队展示了结合 GBDT 和 GNNs 的 BGNN 模型的训练,以解决任何节点级预测问题,如半监督节点回归或分类。在首次迭代中,研究团队通过最小化在训练节点上的平均损失函数  ),  ,来建立带有 k 棵决策树的 GBDT 模型。借助所有预测   ,将节点特征更新为   ,然后传递给 GNNs。

整体来看,BGNN 模型的主体架构只由两个连续的块 (GBDT 和 GNNs) 组成,并且执行的是端到端的训练,因此可以从两个角度来阐述两者间的关系:GBDT 是 GNNs 的嵌入层,或者 GNNs 是 GBDT 的参数损失函数。

更具体地,在第一种情况下,GBDT 将原始输入特征   转换为新的节点特征   ,然后将其传递给 GNNs。而第二种情况下,可以将 BGNN 视为标准的梯度提升训练,其中 GNNs 则充当依赖于图拓扑的复杂损失函数。

实验及结果分析

团队成员对 BGNN 和 Res-GNN 进行了比较评估,对比了各种强基线和先前关于异构节点预测问题的方法,从而在所有方面均实现了显著的性能提升。为了确保实验的公平性,在训练每个模型时,保证验证集上的超参达到收敛状态,并根据三次超参设置的运行结果,取均值。在这部分中,主要概述了三方面的内容:实验设置,异构节点回归问题的结果以及提取的特征表示。

表 1

针对于异构节点的回归问题,研究团队使用了五个具有不同属性的真实世界的节点回归数据集,具体统计情况如表 1 所示。其中,四个数据集(House、County、VK、Avazu)是异构的,也就是说输入要素之间彼此独立,很可能具有不同的类型、规模和含义。而剩下的 Wiki 数据集是同构数据集,节点之间是相互依存的,并且对应于维基百科的词袋表示。

表 2

表 2 则给出了各模型间对节点回归的比较评估结果。从表中的报告数据可以明确的得出,团队提出的 BGNN 模型相比于基线有了显著的提升。特别是,在异构的情况下,BGNN 分别使 House、County、VK、Avazu 数据集的错误率减少了 8%、14%、4% 和 4%。使用预训练的 CatBoost 模型作为 GNNs 输入的 Res-GNN 模型也降低了均方根误差值(RMSE, Root Mean Squared Error),但不如端到端模型 BGNN 降低得多。在同构数据集 Wiki 中,相比于 CatBoost 模型以及随后的 Res-GNN 和 BGNN 模型,GNN 模型的表现更好一些。

直观地说,在特征同构的情况下,神经网络方法就足以获得最佳结果。这也潜在表明,BGNN 可以获得更好的定性结果,其端到端的训练方式在表格数据图的节点预测任务中占据明显的优势。除此之外,端到端的组合方法 FCNN-GNN 比单纯使用 GNNs 所获得的性能更好。但是,与融合 GBDT 优势的 BGNN 模型相比,其改进幅度仍然偏小。

需要注意的是,CatBoost、LightGBM 和 FCNN 可以单独发挥作用,但它们的性能在所有数据集上的显示并不稳定。总体而言,这些对照实验有力的证明了 BGNN 模型相对于其他强模型的优越性。

对于节点分类而言,研究团队使用了五个具有不同属性的数据集。由于缺乏具有异构节点特征的公开数据集,团队通过将目标标签转换为若干离散类,采用回归任务中的数据集 House 类和 VK 类。另外,还引入了两个来自异构信息网络(HIN, Heterogeneous Information Networks)的稀疏节点分类数据集 SLAP 和 DBLP,其中节点包含有几种不同的类型。为了完整起见,在实验过程中,团队还加入了一个同构数据集 OGB-ArXiv。

团队成员留意到,该数据集中的节点特征对应于 128 维的特征向量,该向量是通过平均标题和摘要中的词嵌入得到的。由此可见,该数据集的特征并非异构,因此推断,与神经网络方法相比,GBDT 并不会很高。

表 3

从表 3 中可以看出,BGNN 方法在具有表格特征的数据集(House 类和 VK 类)上的结果存在显著优势。例如,对于 VK 类数据集,BGNN 在精确度上实现了 18% 以上的增加量。这表明,GBDT 和 GNNs 的联合学习表示形式在具有异构特征数据的节点分类设置时同样有用。

其他的两个数据集 Slap 和 DBLP 具有稀疏的词袋特征,这对 GNNs 模型来说是个十足的挑战。在这两个数据集中,GBDT 作为最强的基线。而且,由于 FCNN 优于 GNN,团队初步猜测,图结构可能无济于事,也就是说在最终呈现的实验结果中,BGNN 不应该优于 GBDT。

当然,事实确实如此:BGNN 的最终精确度略低于 GBDT。

在同构数据集 OGB-ArXiv 中,FCNN-GNN 和 GNN 模型性能最高,其次是 Res-GNN 和 BGNN 模型。简而言之,GBDT 无法对同构输入特征进行良好的预测,因此降低了 GNN 的判别能力。在数据集具有稀疏性和同构特征的两种情况下,均表明 BGNN 的性能与 GNN 相当或更高。但是,由于数据中缺乏异构结构可能使得 GBDT 和 GNNs 的联合训练存在冗余现象。

考虑到 BGNN 模型在各种数据集上的性能明显优于强基线,因此团队又做了补充实验,测试了使用不同 GNNs 模型时对改进的影响。为了解开疑惑,团队比较了四种 GNNs 模型,分别是 GAT、GCN、AGNN 和 APPNP 模型。做法是先将这些模型分别替换为 Res-GNN 和 BGNN 模型,然后测量相对于原始 GNNs 性能的变化情况。实验结果如图 2 所示,其纵坐标表示每种 GNN 模型架构的 Rse-GNN 和 BGNN 之间的 RMSE 差距。通过实验结果,证明了所有经过测试的 GNNs 架构都能从本文团队所出的方法中受益匪浅。以 House 数据集为例,对于 GAT、GCN、AGNN 和 APPNP 四个模型,均方差分别减少了 9%、18%、19% 和 17%。

此外,还可以清晰的看到,BGNN 的端到端训练(红色方格)比 Res-GNN 模型中 CatBoost 和 GNN 的简单组合(黄色斜纹)带来的改进更大。这再次有力的证实了团队所提出方法的有效性。

先前的实验在模型性能方面给出了证明,那么在时间效率方面如何呢?

为了回答这个问题,团队成员分别测量了每个模型自开始训练到收敛的准确时间,具体结果呈现在表 4 中。很明显,大多数情况下,BGNN 和 Rse-GNN 的运行速度都要比 GNNs 快。这也就说明,BGNN 和 Rse-GNN 模型比 GNN 更加有效,在提高性能的基础上并不会增加时间成本。例如,对于 VK 数据集,BGNN 和 Rse-GNN 的运行速度分别比 GNN 快 3 倍和 2 倍。

最后做下总结,本文提到的新颖方法 BGNN,是一种端到端的方法,可以与任何消息传递神经网络和梯度增强方法结合使用。它首先利用 GBDT 构建异构数据常见的超平面决策边界,然后借助 GNNs 使用关系信息来提升预测。最终通过大量的实验证明,BGNN 在预测精度和训练时间方面均优于现有的方法。作者提示,可以将此方法扩展到图级别的预测任务上,如图分类或子图检测等有前景的方向上。


往期精彩回顾



适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑
本站qq群704220115,加入微信群请扫码:

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值