ICML2019 Self-Attention Graph Pooling

本文转载自知乎专栏: NLP入门论文解析 

作者: Taki

https://zhuanlan.zhihu.com/p/104837556

小编语: 经典的CNN架构通常包含卷积层和池化层. GNN将CNN泛化到了图数据上,在很多领域得到了广泛的应用.但是,目前的GNN主要关注如何定义节点的邻居并聚合邻居信息(也就是卷积层的设计),对于池化并没有较多的关注. 那么,图上的池化层该怎么来做呢?有什么独特的挑战呢? 

Links:arxiv.org/abs/1904.0808

Github:github.com/inyeoplee77/

Conference:ICML2019

0 Abstrcat

将深度学习的框架迁移到结构化的数据近期也是个热点,近期的一些学习都在将卷积+pool迁移到非结构化的数据中【模仿CNN,目前LSTM,GRU还没有人模仿】,将卷积操作迁移到Graph中已经被证明是有效的,但是在Graph中downsampling依旧是一个挑战性的问题,上篇论文也研究了Graph Pool【link】。在这篇论文,我们提出了一种基于self-attention的graph pool方法,我们的pool方法包括node feature/graph topology两个特征,为了公平,我们使用一样的训练过程以及模型结构,实验结果证明我们的方法在Graph classification中表现的好。

1 Introduction

深度学习的方法在数据的识别和增强等方面有突飞猛进。特别,CNNs成功地挖掘了数据的特征,例如 欧几里得空间的images,speech,video。CNNs包括卷积层和下采样层,卷积和池化操作都有 shift-invariance特性(也叫作 stationary property)。因此CNNs只需要很小的参数就可以获得较好的结果。

在很多领域,然而,大量的数据,都是以非欧数据的形式储存的,例如 Social Network,Biological network,Molecular structure都是通过Graph中的Node以及Edge的形式表示的,因此很多人尝试将CNN迁移到非欧空间数据上。

在Graph Pool领域,方法远远少于Graph Convolution,之前的方法只考虑了Graph topology,还有其他的方法,希望获得一个更小点的Graph表示,最近,也有一些方法希望学习Graph的结构信息,这些方法允许GNNs用一种End2End的方法。然而,这些池化方法都还有提升空间,譬如需要立方级别的存储复杂度,Gao & Ji 的方法[link]解决了复杂度的问题,但是没有考虑图的拓扑结构。

由此我们提出了一种 SAGPool模型,是一种 Self-Attention Graph Pooling method,我们的方法可以用一种End2End的方式学习结构层次信息,Self-attention结构可以区分哪些节点应该丢弃,哪些应该保留。因为Self-attention结构使用了Graph convolution来计算attention分数,Node features以及Graph topology都被考虑进去,简而言之,SAGPool继承了之前模型的优点,也是第一个将self-attention 加入Graph pooling中,实现了较高的准确度。

2 Related Work

2.1 Graph Convolution

Graph卷积在之前文章里说了很多了。主要是Spectral/Spatial两个领域,Spatial领域的卷积主要是直接在Graph中计算,中心节点通过邻接矩阵接受邻居节点的信息。Hamilton et al提出了GraphSAGE,通过采样传递信息学习节点的embedding。但是GraphSAGE是在一个固定大小的邻居域内操作,而GAT模型是根据attention结构,计算节点所有的邻居,然后信息传递直到稳态(stationary)。

2.2 Graph Pooling

Pooling layer让CNN结构能够减少参数的数量【只需要卷积核内的参数】,从而避免了过拟合,为了使用CNNs,学习GNN中的pool操作是很有必要的,Graph pool的方法主要为三种:topology based,global,hierarchical pooling。

Topology based pooling。早先的工作使用Graph coarsening 算法,而不是神经网络。谱聚类方法使用特征值分解来获得粗化图,但是特征值分解是一个花费算力的过程,Graclus不需要特征值,使用介于general spectral clustering objective and a weighted kernel k-means objective方法。

Global pooling。不想之前的方法,global pool方法考虑Graph feature,使用summation or neural networks在每一个layer pool所有的节点表示,不同结构的图都适用,因为这个方法一次性处理所有的representations。Gilmer et al. 将GNNs看做信息传递的结构,提供了一个通用的结构,可以使用Set2Set来初始化进而获得Entire Graph的表示。SortPool根据在结构中的角色来给节点的embedding 分类,然后将排序后的embedding传递到下一层。

Hierarchical pooling。Global pooling方法仅仅学习了节点的层次信息,这对捕捉Graphs的结构化信息是非常重要的,hierarchical pooling methods的主要动机在每一个layer建立一个可以学习feature-或者topology-based 节点表示。Ying et al 提出了一种 DiffPool,这是一种使用End2End fashion的新方法,在layer L有一个可学习的assignment matirx,  包含layer L中所有节点聚类到layer L+1上的概率,其中  代表L层节点总数,特别地,Nodes通过以下方程确定:

其中 X代表Node 特征矩阵,A是邻接矩阵。

Cangea et al 初始化 gPool,实现了与DiffPool类似的效果,gPool需要储存复杂度为  ,其中DiffPool需要  ,其中  ,  包括顶点,边,pooling比例  ,gPool使用一个可学习的向量  来计算projection scores,然后使用这个scores来选择top ranked nodes【前几个节点】。Projection scores通过向量p以及所有节点的特征的dot product。这个scores表明节点的信息保留程度,下面的方程大致表示 pooling produce in gPool:

正如方程2所示,graph topology不影响projection scores。

为了更好提高graph pooling,我们使用SAGPool,可以使用node feature/graph topology来实现pool。

3. Proposed Method

SAGPool的key points是使用GNN来计算self-attention scores。在3.1结构我们详细叙述了SAGPool的结构以及它们的变体。3.2中详细解释了模型的结构,SAGPool layer和模型结构在Fig 1和Fig 2中展示出来。

3.1 Self-Attention Graph Pooling

Self-attention mask。Attention结构已经在很多的深度学习框架中被证明是有效的。这种结构让网络能够更加重视一些import feature,而少重视一些unimportant feature。特别,self-attention通常也被叫做intra-attention,允许输入的features由attention self来评判,例如,如果使用Kipf&Welling的卷积公式,self-score  用下面的式子来计算:

其中  是激活函数,  是带self-connection的邻接矩阵(  ),  是  的度矩阵,  是N个节点,F维特征的Graph的input feature,  是SAGPool layer的唯一参数。通过初始化Graph convolution来获得self-attention scores,pooling的结果是基于graph features和topology结构。我们类似Gao & Ji;Cangea et al,纵然Graph很小【Node很少】依旧可以保持部分节点。Pooling比例  是一个保持节点数量的超参数,根据Z的值来选择前[kN]个节点:

其中 top-rank是一个返回前[kN]个数据的序号,  代表取index操作,  是feature attention mask。【top-rank(Z,[kN])指的是用Z指标来排序,选取前kN个节点】

Graph pooling 输入的Graph根据Fig 1中提到的mask操作一样操作

其中  是一个row-wise(i.e. node-wise)indexed feature matri,  是broadcasted elementwise product【将Z按照列扩增,然后点乘】,  是row-wise以及col-wise indexed邻接矩阵,  以及  是新的feature matrix以及对应的邻接矩阵。【  】

Variation of SAGPool 在SAGPool中使用Graph convolution的主要原因是想利用节点特征的同时利用好图的拓扑结构。如果将node feature以及邻接矩阵作为输入,公式3可能有不同的版本。计算attention score  的公式如下:

其中X标记着节点特征矩阵,A是邻接矩阵。

有很多方式来计算attention scores,可以不仅仅使用相邻节点,还使用multi-hop connected nodes。在方程7+8中,我们说明了使用2-hop 相连节点的例子,可以使用边的增强或者GNN layers的堆叠两种方法。

将邻接矩阵的平方加进去可以引入2-hop信息【邻接矩阵的平凡=2-hop graph的邻接矩阵】:

堆叠使用GNN layers可以间接通过2-hop节点传播信息。在这种情况下,非线性以及SAGPool的参数数量增加了。

另外一个思路就是将多个attention scores求平均,通过M个GNN计算得到下面的分数:

【说白了就是沿用 multi-head attention的思路】

在这篇论文,使用方程 7/8/9的称之为  。

3.2. Model Architecture

根据 Lipton & Steinhardt,如果同时修改了一个模型的多处,那么很难看出是哪些改动对模型起了促进作用【这都能引用论文,真的是哲学】。为了公平竞争,使用Zhang et al和Cangea et al的论文的模型结构,用相同的结构测试我们的方法与baselines。

Convolution Layer 如2.1部分讲的那样,Graph convolution有很多定义,使用其他类型的图卷积可以提高性能,为了统一,使用Kipf & Welling 的方法用于所有的Model。方程 10类似方程3,除了维度  。

其中  是l-th layer的节点表示,  是卷积权重,输入维度为F,输出维度为F',使用ReLU(Nair & Hinton,2010)作为激活函数。

Readout layer 由JK-net结构,使用一个readout layer用节点特征生成一个固定大小的表示,使用readout layer的输出如下所示:

其中N是节点数量,  是节点i的特征向量,||代表连接操作。、

Global pooling architecture 使用一个Global pooling结构,如Fig 2所示

Global pooling结构包含三个卷积层,将它们的输出连接起来。Node features在readout layer+pooling layer之下流动,Graph feature representions之后传输到线形层做分类。

Hierarchical pooling architecture 在这个设置下,如Fig 2所示那样,做一次卷积,做一次pooling,最后将三次pooling的结果加起来使用MLP来分类。

4. Experiments

使用Graph分类问题来测试global pooling/hierarchical pooling方法。在4.1讨论了evaluation的数据集,4.3描述了我们怎么训练模型,结果的比较在4.4/4.5中描述。

4.1. Datasets

Graph中节点超过1K的benchmark datasets共有5个节点,datasets的数据在Table 1中描述。

4.3 Training Procedures

Shchur et al 证明数据集的不同分割可能会影响GNN模型的表现。在我们的实验里,我们使用10-fold cross validation,使用20个随机种子,每个数据集一共得到200个实验结果。在训练的时候,10%的训练数据用来做验证。使用Adam optimizer,early stoping。如果每50轮后validation准确率没有明显提升就停止训练,最多训练100K轮次。使用grid search的方式来确定参数。

4.4 Baselines

我们考虑以下四种pooling方法:Set2Set,SortPool,DiffPool,gPool。DiffPool,gPool以及  使用 hierarchical pooling结构,Set2Set, SortPool以及  使用global pooling 结构。我们对所有的baselines/SAGPool,参数在Table 2总结了。

Set2Set 需要一个额外的超参数,也就是LSTM的processing steps的数量。所有的实验使用10 processing step。我们假设 readout layer是非必要的,因为LSTM 模型生成的Graph的embedding是不保序的。

SortPool 是最近提出的一个使用sorting来pooling的方法。当设置节点数为K的时候,60%的图都会有超过K的节点,在Global pooling setting,  与SortPool有一样的K个输出节点。

DiffPool 是第一个可训练的End2End graph pooling方法,可以产生Graph的Hierarchical representions。我们没有在DiffPool上没有使用batch normalization,这与pooling method是不相关的。对于超参数search,pooling比例从0.25-0.5变化,在reference implementation,cluster size的上限设置成节点总数的25%。  当pooling rate超过0.5的时候会导致out of memory。

gPool 选择top-ranked 节点,这与我们的方法是类似的,区别在于我们的模型考虑了图的拓扑结构,这有助于提高在分类任务上的结果。

4.6 Summary of Results

结果是如Table 3/Table 4所示。可以看出SAGPool表现不错。

5. Analysis

在这个部分,我们提供了实验结果的分析。

5.1 Global and Hierarchical Pooling

很难说哪种结构更好,POOLg最小化了信息的loss,在节点数较少的数据集【NCI1,NCI109,FRANKENSTEIN】上表现更好,然而,POOLh在节点数多的数据集上表现更好【D&D,PROTEINS】,因为POOLh在提取large scale graphs信息方面更拿手。不过,SAGPool比其他的模型更好。

5.2 Effect of Considering Graph Topology

SAGPool使用的是第一阶近似,这也就允许SAGPool考虑图的拓扑结构。正如Table 3,考虑图的拓扑结构表现更好。此外,graph Laplacian不需要重新计算,因为可以从相同位置的Graph Convolution layer传递过来。当SAGPool和gPool一样的参数,SAGPool在图分类问题上表现更好。

5.3 Sparse Implementation

邻接矩阵经常是稀疏的,所以Manipulating graph data with a sparse matrix非常重要。使用sparse martix可以降低内存复杂度和时间复杂度。

5.4 Relation with the Number of Nodes

在DiffPool中,cluster zise需要重新定义,因为GNN使用了matrix S。cluster size需要设置合适,否则会导致2个问题。1)参数的数量和最大节点数是相关的。

2)当Graph的大小不一的时候,很难决定合适的cluster size。

5.5 Comparison of the SAGPool Variants

省略

5.6 Limitations

我们在模型中使用了pooling ratio k,来处理不同的various size。在SAGPool,我们难以量化这个比例K,为了解决这个问题,我们使用二分类来觉得节点的保留与否,但是这没有解决问题。

6 Conclusion

老生常谈。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值