声明
本篇论文的主要内容来自于斯坦福大学的博士生Rex Ying,论文名称为:Hierarchical Graph Representation Learning withDifferentiable Pooling。论文地址:点击下载。但需要说明的是本篇文正并不是对论文的翻译,书中大部分内容是作者对论文理解,当然可能个人水平有限,中间难免会出现一些错误,如若发现恳请指出,不胜赐教。
背景
近年来人们对卷积神经网的研究越来越热门化,其成果被广泛应用于计算机视觉、自然语言处理等诸多领域。但是人们深入研究的过程中也发现了cnn的诸多不足。其中最大的局限性之一就是无法进行因果推理。对于该问题工业界一直在进行积极探索,其中一个很有前景的方向就是图神经网络,简称为GNN。GNN在对图进行处理的过程中将底层的输入图作为一个计算图,通过在图中船体、转化和聚合特征节点信息来学习生成单个节点嵌入。然后将生成的节点嵌入用作微分预测层的输入,例如用于节点分类或者链接预测。然而当前GNN有一个极大的限制就是其处理过程是平面化的即信息的传递更多是在边上进行,而不是以层级方式推断和聚合信息。但是对于图分类认为来说,层级结构的确实确实是一个比较严重的问题,因为该类任务主要是用于预测出整个图相关的标签。如果对该类任务用传统的GNN方式进行处理,它会对图中所有的节点嵌入进行全局池化,这种全局池化的方式忽略了图中存在的层次结构,不利于生成有效的GNN模型。因此作者在本篇论文中提出了“基于差分池化的分层图表示方式”,这是一个可以分层和端到端的方式应用于不同图神经网络的可微图池化模块。DIFFPOOL允许开发可以学习在图的层级表征上运行的更深度的GNN模型。他们开发了一个和CNN中的空间池化操作相似的变体,空间池化可以让深度CNN在一张表征越来越粗糙的图上迭代运行。与标准CNN相比,GNN的挑战在于图不包含空间局部性的自然概念,也就是说,不能将所有节点简单地以[m×m patch]的方式池化在一张图上,因为图复杂的拓扑结构排除了任何直接、决定性的[patch]的定义。此外,与图像数据不同,图数据集中包含的图形节点数和边数都不同,这使得定义通用的图池化操作更具挑战性。
算法概述
该算法通过通过上一层的节点嵌入将本层节点映射为一组堆叠,然后生成的堆叠作为下一组的节点嵌入,一次类推。以下图为例
第一层通过上一层的节点嵌入生成一组堆叠,然后该组堆叠作为第二层的节点嵌入,当然该图表示结构只有三层,实际的GNN模型可能会有十几层。在这种方法的处理下,每处理一层图就会越粗化,并且通过该方法训练之后可以产生任何输入图的层级表征。
接下来我们对算法的具体过程进行描述,开始之前我们首先要明确两个概念,第一个是训练集,论文中作者选择的训练集命名为G,其中G1、G2、G3都有两部分组成即A和F,A表示的是图的邻接矩阵,F表示的节点特征矩阵。而y1、y2、y3都输入分类集合y代表的是某个具体的分类标签。我们的目标是训练处一个模型,该模型通过给出其输入图,然后输出其对应的分类标签。
在论文中作者选择的是GNN中的
message passing
方式作为GNN模型,该模型的具体构成如下图所示
通过多次训练逐步找出训练参数,然后得出训练模型。上边的公式我们可以看到其核心是传播函数M,针对M的选择有好多种,作者在论文中选择的是将线性变化和RelU非线性激活函数结合起来而形成的一个传播函数。该函数的具体形式如下所示
比较复杂如果大家感兴趣可以查相关文献,这里就不在进行详述了。作者本篇论文中给出的基于分配学习的可微池化方法抽象说就是,通过给定第L层的节点嵌入以及第L层的邻接矩阵,产生第L+1层的邻接矩阵以及第L+1层的粗化图的特征矩阵。具体讲通过下边两个公式,第一个公式产生第L+1层的粗化图特征矩阵,第二公式产生第L+1层的节点嵌入。大家注意作者为了得到第L层的节点嵌入和第L层软分配这里他用到了两个GNN模型(两个模型是不同的)以及softmax函数。
实验验证
作者为了证明自己方法的优越性设计了一系列实验来进行验证,他在论文中给出了三个需要验证的问题也就是实验问题。
- Q1:与其他已提出的GNN池化方法相比,DIFFPOOL如何?
- Q2:与现有最好的图分类任务模型相比,结合了DIFFPOOL的GNN如何?
- Q3:DIFFPOOL对输入图给出了有意义且可解释的簇吗?
为了保证方法的泛化能力,作者使用多种图的分类基准数据包括蛋白质数据集、社交网络数据集以及科学协作数据集作为输入。设计了一个多层的神经网络模型,其中GNN部分选择的是GCN的变体CRAPHSAGE作为GNN模型,其中每两个GNN层都加上一层微分池化层。通过实验作者也给出了实验结果
从这张图函数我们可以看到微分池化方法在GNN池化方法中获得了最高的平均性能值,其中四个基准测试中有四个达到了最优。因此我们可以的出结论与其他GNN池化方法相比,微分池化方法表现是最好的。针对第三个问题的回答。作者给出了COLLAB数据集的前两层节点分配的可是可视化图,我们可以看出微分池化可以学习稀疏图中存在的有意义的部分,并且可以对输入图给出有意义且可解释的簇。