将边缘特征纳入图神经网络以预测国家 GDP

将边缘特征纳入图神经网络以预测国家 GDP

图神经网络 (GNN) 是一种非常灵活的技术,可以应用于各种领域,因为它们可以推广假设更严格数据结构的卷积和序列模型。在本文中,我们使用基于注意力机制的 GNN,根据国际贸易流量和特定国家的数据预测特定年份各个国家的 GDP。

我们在本文中描述了我们的方法,但也欢迎您随时关注我们的Google Colab 。

设置、数据源和预处理

任何图形机器学习问题的第一步都是了解所考虑的任务类型。由于图形非常灵活,我们在决定什么是边、什么是节点以及什么是我们期望的结果时有很多自由度。然而,在我们的例子中,设置相对简单:

  • 节点是国家
  • 节点特征是国家级统计数据,例如人口
  • 贸易伙伴之间的边缘
  • 边缘特征是交易的特定商品的数量

我们期望的结果也很简单,因为我们在对连续随机变量(即 GDP)进行节点级预测。实际上,我们正在学习 GDP 的预测因子,我们可以使用均方误差对其进行评估。

为了收集数据,我们依赖两个来源。第一个是BACI [1],这是法国全球经济研究所 CEPII 的一个项目,它记录了自 1995 年以来所有国家之间的双边贸易流量。该数据集非常庞大且详尽,因为它包含所有国家对的产品级信息。第二个数据来源是世界银行,它作为致力于发展的国际机构,收集了全世界的国家级统计数据。我们将此来源用于我们的标签(GDP)以及我们添加到模型中的三个节点级特征:就业水平、通货膨胀率和人口。

  • 国内生产总值:https://data.worldbank.org/indicator/NY.GDP.MKTP.CD [2]
  • 人口:https://data.worldbank.org/indicator/SP.POP.TOTL [3]
  • 失业率:https://data.worldbank.org/indicator/SL.UEM.TOTL.ZS [4]
  • 居民消费价格指数:https://data.worldbank.org/indicator/FP.CPI.TOTL.ZG [5]

方法

图神经网络概述

在深入研究任何代码之前,回顾一下 GNN 的实际工作原理会很有帮助。在最简单的层面上,GNN 尝试将节点、边或图表示为向量,然后可以将其用于传统的下游机器学习工具,例如多层感知器。理想情况下,我们希望这些向量表示与原始图具有某种有意义的关系。我们假设,通过结合节点级特征、边级特征和节点关系,我们可以学习与 GDP 相关的表示。GNN 通过一种称为“消息传递、聚合和更新”的技术学习将国家映射到此类向量表示。

img

图 1:GNN 使用节点的特征及其与其他节点的关系来找到合适的向量表示。左图:Zachary 的空手道俱乐部网络 [6],该图包含 34 个节点,分为两个类别之一:“Mr. Hi”和“Officer”。中图:GNN 的简化图。右图:Zachary 的空手道俱乐部中节点的向量表示。请注意不同类别节点之间的分离。

与任何神经网络一样,GNN 由多层组成。在每一层,每个节点通过收集、转换然后聚合其直接邻居的向量表示来更新其向量表示。例如,一个节点可能会对其相邻节点的特征执行线性变换,对得到的乘积求和,然后在最后一步应用 LeakyReLU 激活。结果将是该节点的更新向量表示,它将应用于 GNN 的下一层。用更数学的方式表达,给定我们 GNN 层处的节点 v 的向量嵌入𝒽ᵥˡ以及该节点的局部邻域𝒩(v),节点 v 在 GNN 下一层的嵌入将是:

img

在上面的更新方程中,是可学习的权重,𝒽ᵤˡ是节点****v的邻居u在层的嵌入,σ是非线性激活函数,例如 LeakyReLU。我们重复此过程,直到对模型的表达能力感到满意为止。我们注意到𝒽ᵥ ⁰ 只是我们上面描述的节点****v的初始节点特征。

img

图 2:GNN 中单个节点的消息传递、聚合和更新的可视化[7]

图注意力网络和边缘特征

我们刚刚解释了一个简单的、原始的 GNN 是如何工作的。在我们的交易流问题中,我们使用了图注意力网络(GAT) [8]。借助 GAT,我们添加了两个新功能,为我们的模型提供了更强的表现力。首先,我们允许每个节点决定哪个邻居更重要;我们通过附加到每个节点v的邻居的注意力参数来实现这一点。与Wᵏ一样,这种注意力权重是可以学习的。将此注意力参数添加到上面的更新方程中,我们得到:

img

其中,我们可学习的注意力参数是αᵤᵥ (从****uv 的注意力权重)。有关如何计算此注意力参数的更多信息,您可以查看原始论文

GAT 所包含的第二个令人兴奋的功能是边缘特征。这非常有用,因为我们的贸易流数据集不仅记录了哪些国家之间进行贸易,还记录了它们之间的贸易内容和数量。具体来说,我们可以查看数据中全球交易量最大的 10 种商品,以及这两个国家之间交易的这些商品的数量。使用 GAT,我们可以包含此类贸易统计数据并将其编码为我们的边缘特征。这些边缘特征究竟是如何融入模型的超出了本文的范围,但感兴趣的读者可以通过阅读 Pytorch Geometric 的GAT 文档了解更多信息!

img

图 3:图注意力网络 (GAT) 在计算节点嵌入时会整合局部邻域信息。为了计算节点 v 在层ℓ的嵌入****hᵥˡ⁺¹,节点v使用权重对其四个邻居的嵌入进行线性变换,应用每个邻居的注意力机制αᵤᵥ,汇总结果,然后应用非线性激活σ

用于 GDP 预测的 GAT

现在我们已经掌握了基础知识,让我们继续为 GDP 预测任务实施 GAT。我们将实施以下流程:

img

图 4:模型架构管道。请注意,最后的 ReLU 是为了确保 GDP 为非负值。

首先,让我们通过读取 BACI 贸易流和世界银行数据集并分别提取边特征和节点特征来加载和预处理我们的数据集:

接下来,我们使用 PyG 的GATConv 模型构建一个简单的两层图卷积网络,后面跟着一个线性层。这个“基线”模型(即不包含任何边缘特征的 GAT)仅依赖于每个国家的单个节点特征和贸易流。

为了进行比较,我们将构建一个包含边缘特征的单独模型*。*我们将比较这些模型,看看哪一个表现最好:

现在,我们训练两种模型并比较它们的性能:

img

图 5:比较 GAT 模型与基线的 log(loss) 值。GAT 模型略胜基线损失

img

img

图 6:2000 年之后模型预测准确率不断提高,显示出准确预测 GDP 的趋势

结果

首先,很明显这两个模型都在随着时间的推移而学习。如图 5 所示,验证数据集上的 MSE 下降得相当平稳,直到大约 2000 次迭代。这表明学习信号正确地通过了模型,并且可以推广到看不见的验证集。请注意,y 轴是对数缩放的,以显示模型正在取得的微小但稳定的收益。

第二个观察结果是,具有边缘特征的模型略微但始终优于不具有边缘特征的基线模型。验证 MSE(针对两个模型在同一数据集上计算)对于具有边缘特征的模型较低,这表明我们的边缘特征(即商品交易量)增加了有关一个国家经济规模的信息。

这是个好消息。让我们继续检查模型的一些缺点,并讨论进一步改进的方法。首先,标签的对数缩放可能会使 MSE 看起来比实际更令人印象深刻。在训练结束时,对于具有边缘特征的模型,验证集上的 MSE 约为 5.4。这意味着每个预测的预期值与真实值相差约 2.3。虽然如果模型与一个国家的真实 GDP 相差 2.30 美元会令人难以置信,但请记住,我们对数据进行了对数缩放,因此损失意味着模型与正确值相差约 2.3 个数量级。结果是我们不会在短期内让世界银行失业。

第二点在我们分析模型如何提高其准确性时就出现了。请看图 6,它显示了模型在初始化时和经过 2000 次训练后对测试数据的预测值与实际值。这两个图表明,模型准确性的提高很大程度上可以归因于简单地学习平均 GDP 并将所有预测值上调相应的量。好消息是,预测值与实际值之间存在约 0.27 的正相关性,这表明预测除了世界 GDP 的平均值之外还包含了更多信息,但这个数字进一步强调,我们不应该因为看似较低的 MSE 而宣布胜利。

下一步该怎么做?我们尝试对模型进行超过 2000 次迭代的训练,但验证准确率开始下降,这表明我们可能过度拟合了。我们还尝试了不同的超参数和模型架构组合,但没有显著效果。最简单且可能最好的方法就是在节点或边缘级别添加额外的特征。虽然如果特征的维度超出数据集的大小,可能会面临过度拟合的风险,但这可能会提高模型的准确性。后续工作可以研究额外特征及其维度对性能的影响。

结论

单且可能最好的方法就是在节点或边缘级别添加额外的特征。虽然如果特征的维度超出数据集的大小,可能会面临过度拟合的风险,但这可能会提高模型的准确性。后续工作可以研究额外特征及其维度对性能的影响。

结论

我们学习了如何实现 Graph Attention 网络,这是一种强大的架构,允许节点了解其邻居的消息有多重要。此外,我们通过添加边缘特征来提高 GAT 的性能,这些特征编码了有关两国之间某些商品贸易量的信息。使用边缘特征可以进一步提高我们模型的表达能力,并利用我们数据中可能存在的其他模式。它们不仅可以应用于我们的贸易网络,还可以应用于许多其他应用。

博客原文:专业人工智能社区

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值