GNN教程:与众不同的预训练模型!

本文介绍了图神经网络(GNN)的预训练框架,通过边重建、Centrality Score Ranking和保留图簇信息等任务,学习通用的图结构特征,以解决标注数据稀缺的问题。预训练的GNN可以增强下游任务的表现,提供更丰富的节点表征。
摘要由CSDN通过智能技术生成

↑↑↑关注后"星标"Datawhale

每日干货 & 每月组队学习,不错过

 Datawhale干货 

作者:秦州,算法工程师,Datawhale成员

0 引言

虽然 GNN 模型及其变体在图结构数据的学习方面取得了成功,但是训练一个准确的 GNN 模型需要大量的带标注的图数据,而标记样本需要消耗大量的人力资源,为了解决这样的问题,一些学者开始研究Graph Pre-training的框架以获取能够迁移到不同任务上的通用图结构信息表征。

在NLP和CV领域中,学者已经提出了大量的预训练架构。比如:BERT(Devlin et al., 2018)和VGG Nets (Girshick et al., 2014),这些模型被用来从未标注的数据中学习输入数据的通用表征,并为模型提供更合理的初始化参数,以简化下游任务的训练过程。

后台回复【GNN】进图神经网络交流群。

这篇博文将向大家介绍图上的预训练模型,来自论文Pre-Training Graph Neural Networks for Generic Structural Feature Extraction 重点讨论下面两个问题:

  1. GNNs 是否能够从预训练中受益?

  2. 设置哪几种预训练任务比较合理?

1 预训练介绍

本节将向大家介绍什么是模型的预训练。对于一般的模型,如果我们有充足的数据和标签,我们可以通过有监督学习得到非常好的结果。但是在现实生活中,我们常常有大量的数据而仅仅有少量的标签,而标注数据需要耗费大量的精力,若直接丢掉这些未标注的数据也很可惜。因此学者们开始研究如何从未标注的数据中使模型受益。

一个简单的做法是我们自己为这些未标注数据"造标签",当然这些标签和我们学习任务的最终标签不一样,否则我们也不用模型学习了。举个简单例子,比如我们想用图神经网络做图上节点的分类,然而有标签的节点很少,这时候我们可以设计一些其他任务,比如利用图神经网络预测节点的度,节点的度信息可以简单的统计得到,通过这样的学习,我们希望图神经网络能够学习到每个节点在图结构中的局部信息,而这些信息对于我们最终的节点分类任务是有帮助的。

在上面的例子中,节点的标签是我们最终想要预测的标签,而节点的度是我们造出来的标签,通过使用图神经网络预测节点的度,我们可以得到1)适用于节点度预测的节点embedding 2)适用于节点度预测任务的图神经网络的权重矩阵,然后我们可以1)将节点embedding接到分类器中并使用有标签的数据进行分类学习 2)直接在图神经网络上使用有标签的数据继续训练,调整权重矩阵,以得到适用于节点分类任务的模型。

以上就是预训练的基本思想,下面我们来看图神经网络中的预训练具体是如何做的。

2 GCN 预训练模型框架介绍

如果我们想要利用预训练增强模型的效果,就要借助预训练为节点发掘除了节点自身embedding之外的其他特征,在图数据集上,节点所处的图结构特征很重要,因此本论文中使用三种不同的学习任务以学习图中节点的图结构特征。通过精心设计这三种不同任务,每个节点学到了从局部到全局的图结构特征,这三个任务如下:

  • 边重建:首先mask一些边得到带有噪声的图结构,训练图神经网络预测mask掉的边;

  • Centrality Score Ranking:通过对每个节点计算不同的 Centrality Score,其中,包括:Eigencentrality, Betweenness, Closeness和 Subgraph Centrality;然后,通过各个 Centrality Score 的排序值作为label训练 GCN;

  • 保留图簇信息:计算每个节点所属的子图,然后训练 GNNs 得到节点特征表示,要求这些节点特征表示仍然能保留节点的子图归属信息。

整个预训练的框架如下图所示,首先从图中抽取节点的结构特征比如(Degree, K-Core, Clustering Coefficient等),然后将这些结构特征作为embedding来学习设定的三个预训练任务,label使用的是从图中抽取的各个任务对应的label,最后得到节点embedding表征接到下游的学习任务中。

Screen Shot 2019-07-15 at 20.53.02

图注:应用 GCN 作为子模块的图预训练框架

2.1 预训练任务介绍

任务 1:边重建

任务 1 的思路是这样的,首先,随机删除输入图  中一些已存在的边以获得带有噪声的图结构  ;然后, GNN 模型使用

作为输入,记作编码器 ,学习到的表征信息输入到 NTN 模型中,NTN 模型是一个解码器,记作 ,以一对节点的embedding作为输入,预测这两个节点是否相连:

其中, 和  采用二元交叉熵损失函数进行联合优化:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值