Kipf博士论文导读《Deep Learning With Graph-Structured Representations》(《使用图结构表示的深度学习》)

1 INTRODUCTION

1.1 STRUCTURE AND HUMAN COGNITION

我们的生活中存在着非常多的结构(Structure),例如原子、分子、社交网络等,然后就很自然地引出一种表达这些结构的形式,也就是图(Graph)。

1.2 Artificial Intelligence and Deep Learning

简单的介绍了下人工智能和深度学习。

1.3 SCOPE AND RESEARCH QUESTIONS

全文主要分为两部分:第一部分,使用明确的图结构数据应用到深度学习的一系列学习任务上,也就是这些数据包括一系列节点以及节点之间的关系;第二部分使用的数据就不包含明确的图结构,需要学习一个模型,来推断或者充分利用数据中的隐藏结构。然后提出了几个问题引导出他所做出的贡献,也就是一系列模型,这些模型文章后面都会讲到。

2 BACKGROUND

然后是背景知识,文章应用的一些符号,简单地介绍了一下多层感知机。这里还提到了GNN的消息传递更新机制,因为这是很多图神经网络的基础,我觉得有必要讲一下。

公式如下:
h i , j = f e d g e ( h i , h j , x ( i , j ) ) , h i ′ = f n o d e ( h i , ∑ j ∈ N i h ( j , i ) , x i ) . h_{i,j}=f_{edge}(h_i,h_j,x(i,j)),\\ h_i^{\prime}=f_{node}(h_i,\sum_{j\in N_i}h_{(j,i)},x_i). hi,j=fedge(hi,hj,x(i,j)),hi=fnode(hi,jNih(j,i),xi).
image-20210919123146466

比如说对于1节点,用其节点隐藏嵌入(node representations h 1 h_1 h1 )与其所有相邻节点的隐藏嵌入,再加上边特征,如果有的话,分别计算它们的边的隐藏嵌入(edge representations),然后将计算的所有边的隐藏嵌入聚合更新得到新的节点隐藏嵌入 h i ′ h_i^{\prime} hi。公式里的 f n o d e , f e d g e f_{node},f_{edge} fnode,fedge可以是任意可微的函数例如神经网络。当然这只是一层消息传递更新,如果有两层的话,第二次更新时1节点再次聚合来自其领域节点的消息,而此时来自邻域的消息也是基于每邻域中节点自己的邻居信息的聚合,每个节点嵌入都包含了其2跳邻域的信息。这就跟卷积很相似了。

image-20210919124756158

2.4 LATENT VARIABLE MODELS

后面讲到变分自编码器的时候再讲。

2.5 CONTRASTIVE LEARNING

主要是损失函数的构建问题,利用一个正样本和K个负样本进行训练,通过人为破坏正样本得到负样本。

3 Graph Convolutional Networks for Semi-supervised Classification

指的是节点分类

切比雪夫图卷积的一阶近似,其卷积公式为:
H ( l + 1 ) = σ ( H ( l ) W 0 l + D − 1 2 A D − 1 2 H ( l ) W 1 ( l ) ) H^{(l+1)}=\sigma(H^{(l)}W_0^{l}+D^{-\frac{1}{2}}AD^{-\frac{1}{2}}H^{(l)}W_1^{(l)}) H(l+1)=σ(H(l)W0l+D21AD21H(l)W1(l))
GCN在此基础上更加简化,直接把两个参数矩阵变为一个,卷积公式为:
H ( l + 1 ) = σ ( ( I N + D − 1 2 A D − 1 2 ) H ( l ) W ( l ) ) H^{(l+1)}=\sigma((I_N+D^{-\frac{1}{2}}AD^{-\frac{1}{2}})H^{(l)}W^{(l)}) H(l+1)=σ((IN+D21AD21)H(l)W(l))
但是由于 I N + D − 1 2 A D − 1 2 I_N+D^{-\frac{1}{2}}AD^{-\frac{1}{2}} IN+D21AD21的特征值在[0,2]之间,会影响模型训练过程的稳定性,作者又做了个调整:
I N + D − 1 2 A D − 1 2 → D ~ − 1 2 A ~ D ~ − 1 2 I_N+D^{-\frac{1}{2}}AD^{-\frac{1}{2}} \rightarrow \tilde D^{-\frac{1}{2}} \tilde A \tilde D^{-\frac{1}{2}} IN+D21AD21D~21A~D~21
其中 A ~ = A + I N \tilde A = A+I_N A~=A+IN, D ~ i , i = ∑ j A ~ i , j \tilde D_{i,i}=\sum_j\tilde A_{i,j} D~i,i=jA~i,j,这就是最后的卷积公式。 A ~ ∈ R N × N \tilde A \in R^{N\times N} A~RN×N

举个例子,一个两层的GCN就可表示成:
Z = f ( X , A ) = s o f t m a x ( A ^ R e L U ( A ^ X W ( 0 ) ) W ( 1 ) ) Z = f(X,A) = softmax(\hat AReLU(\hat AXW^{(0)})W^{(1)}) Z=f(X,A)=softmax(A^ReLU(A^XW(0))W(1))
其中 X ∈ R N × d i n X\in R^{N\times{d_{in}}} XRN×din 是图的特征矩阵, N N N表示节点数, d i n d_{in} din是输入的特征维度, W ( 0 ) ∈ R d i n × d h i d W^{(0)}\in R^{d_{in}\times d_{hid}} W(0)Rdin×dhid是第一层的参数矩阵, W ( 1 ) ∈ R d h i d × d o u t W^{(1)}\in R^{d_{hid}\times d_{out}} W(1)Rdhid×dout是二层的参数矩阵。最后得到的 Z ∈ R N × d o u t Z\in R^{N\times d_{out}} ZRN×dout就是经过分类的节点。

然后就是损失函数:
L = − ∑ l ∈ Y L ∑ f = 1 d o u t Y l , f l n Z l , f \mathcal L = -\sum_{l\in \mathcal Y_L}\sum_{f=1}^{d_{out}}Y_{l,f}lnZ_{l,f} L=lYLf=1doutYl,flnZl,f
例如 Y = [0,0,0,1,0,0],Z = [0.1,0.1,0,0.7,0.1,0],这个就跟交叉熵差不多了。这个 Y ∈ R N × d o u t Y\in R^{N\times d_{out}} YRN×dout就是标签。

数据:

这三个数据集都是论文引用图数据,其中节点表示每一篇论文,边表示论文之间的引用或者被引用关系,Classes表示每篇论文属于哪个种类,以每篇论文的词袋向量作为Features。训练数据是从每个class里面选取20个数据组成。

image-20210922213546806

实验结果

image-20210919173828954

image-20210919173814995

再加上相关工作和实验部分

4 Link Prediction with Graph Auto-Encoders

任务:边关系的预测

图自编码器GAE:

上一章介绍了图结构数据的节点分类,这一章则是边关系的预测,例如预测两个节点间是否应该有一条边将其连接在一起。

通过一点改变就能将上一章介绍的GCN转换成这一章要介绍的用于边关系预测的图自动编码器 GAE,它是一种基于编码器-解码器结构的模型。

本章中,编码器是基于图的神经网络,将节点特征 X ∈ R N × d i n X\in R^{N\times d_{in}} XRN×din 作为输入,输出新的节点表示 Z ∈ R N × d o u t Z\in R^{N\times d_{out}} ZRN×dout,其中一般 d i n > d o u t d_{in}>d_{out} din>dout,这其实就是图的嵌入,也可以叫做降维。在这里呢,它实际上就是用了GCN作为编码器,就是上一章的GCN。

image-20210925134701933

解码器呢,本文中就是编码器输出的内积, Z Z T ZZ^T ZZT,通过其重构出图结构的连接性,就是图的邻接矩阵,然后让它和真实的邻接矩阵求交叉熵,也就是学习的损失函数。

image-20210925134726352

损失函数

image-20210925134744193

稀疏图的损失函数

image-20210925134828913

权重是正样本和负样本的比, w p o s = N 2 / ( 2 N p o s ) w_{pos} = N^2/(2N_{pos}) wpos=N2/(2Npos) w n e g = N 2 / ( 2 N n e g ) w_{neg} = N^2/(2N_{neg}) wneg=N2/(2Nneg) N p o s = ∣ ϵ ∣ N_{pos} = |\epsilon| Npos=ϵ是图中边的数量, N n e g = N 2 − N p o s N_{neg} = N^2 − N_{pos} Nneg=N2Npos是邻接矩阵A中0的数量

v2-d9c5e951f11f291f5ccb133a2891b4d0_1440w

变分图自编码器VAE:

v2-060c5793ba30c8f422f5bbf55f288ccb_1440w

image-20210925150555899

image-20210925150608672

image-20210925150618799

5 Modeling Relational Data with Graph Convolutional Networks

在前两章,作者介绍了在带有属性的无向图上的节点分类和边预测。在本章,作者将GCN的建模能力扩展到了多关系图,也就是节点和节点之间边可以有多种类型。多关系图最典型的例子就是知识图谱,但是目前所有的知识库都一个很致命的问题就是他们都是不完整的,有很多边关系和节点标签是缺失的,需要去补全。所以本章的学习任务就是利用基于GCN的自编码器去补全缺失的关系。

image-20211014113450575GCN的公式:

image-20211014114732889

R-GCN的公式:

image-20211014114746469

对比发现就是R-GCN对每种关系都单独用了一个权重矩阵。

R-GCN跟CNN有一点相似,卷积核对应图像上的每一个位置的像素点都是一种(特殊的)关系, c i , r c_{i,r} ci,r就等于1。

但是R-GCN这样对每种关系都单独用一个权重矩阵也是有问题的,例如当一张图的关系类型特别多的时候,训练参数就会急剧增加,很容易导致过拟合,作者应用了基数分解来解决这个问题:

将每一个权重矩阵分解,表示成B个参数共的矩阵的线性组合,其中基数V是在所有关系间共享的,只有a是跟每个关系有关的:

image-20211014121053919

image-20211014130907872

如果将R-GCN作为表示学习的一种方法,可以将其它应用到节点分类和边预测两种模型,节点分类跟第一章讲的GCN模型结构基本一致,只是将GCN换成了R-GCN。边关系预测模型在替换GCN的同时还换了解码器,将内积换成了DistMult,

image-20211014131459605

先将节点和关系通过浅层嵌入到同一维度,再求内积作为评分。

那么浅层嵌入又是什么?其实就是维护一个二维表(embedding dim, num embeddings),这个二维表也是一个训练矩阵,每一列对应一个结点的浅层嵌入,通过节点的编号获取嵌入,所以也叫做嵌入查寻。

image-20211014140316222

要注意的是,在这个场景下,输入数据是没有特征属性的,也就是以往存在的X矩阵是没有的,只有一张图。所以就需要对节点进行浅层嵌入来作为他们的属性。

边预测的整个流程就是:input -> Shadow Embedding -> R-GCN -> DistMult -> Loss

对输入的节点进行浅层嵌入,将得到的嵌入作为输入应用R-GCN,R-GCN就可以说是一种深度嵌入,再利用得到的嵌入计算DistMult评分(把代表关系的向量对角化),再计算交叉熵损失函数。

image-20211014141841797

6 Neural Relational Inference for Interacting Systems

本章要讲的模型是NRI,神经关系推断模型

任务:学习目标是从交互系统的数据中学习出潜在的网络结构,并且为这个动态的交互系统建模。

先介绍一下数据:

image-20211017193939139

image-20211017194509283表示所有N个粒子在时刻t的属性,image-20211017194614271表示粒子i所有时刻的属性,属性数据一共有四维,两个位置量和两个速度量。所以总的数据就是image-20211017194739367N个粒子在T个时刻内的轨迹数据。

所以具体的任务就是通过刚刚介绍的数据推断出这些粒子之间的关系,这里粒子之间的关系可以是多类型的,但是这里只有两个类型,类似于与一般的有向图,分出和入,并且学习出这个动态系统的模型。

image-20211017195331543

模型主要包括两部分,编码器和解码器,编码器的功能是利用输入的轨迹数据,推断出粒子之间潜在的连接关系。他将初始的图结构就设置为完全图,也就是任意两个节点之间都有连接,但是不包括自环。

CNN Encoder:

image-20211018125423477

NRI网络模型Encoder

首先将每条边对应的起点和终点属性拼接在一起,以边的视角来计算,然后带入CNN卷积层和MLP运算,CNN包括两个一维卷积层,MLP就是一般的多层感知机,可以结合之前讲的消息传递机制对应着看,这就相当于第一个公式,对每个结点,计算与他相邻的边的关系,这个f_edge函数,就相当于这里的CNN加MLP;然后这一步,就是聚合操作了,将每个节点连接的边聚合起来,带入一个MLP;然后又是一个CONCAT操作,不过这里还有个跳跃连接,将之前得到的数据再拼接到这里,经过一个MLP和全连接输出层,最后得到每条边的可能的关系。

image-20211018134020767

这个原理其实跟图注意力机制很相似,图注意力机制是用一条边的两个节点属性计算这条边的权重,在这里由于假设的是完全图,所以它可以计算所有边的权重,用这个权重来表示这条边存在的概率。

RNN Decoder:

image-20211018141025808

NRI网络模型Decoder

解码器的功能通过原始的轨迹数据和编码器得到的边关系数据对下一个时间步的轨迹数据进行预测,从而学习出动态系统的模型。

再看看模型图,因为包含时间维度,所以整个的模型是RNN的结构。

首先初始化hidden矩阵,经过拼接,然后这一步是对每一个不同的边类型使用不同的函数,然后求和。刚刚讲过,两点之间的可能有多种边类型,每一种类型有一个概率,就是刚刚编码器的输出,其实也就是不同边类型的邻接矩阵,只不过他不是0和1,而是概率,所以这一步就相当于用不同类型的边的邻接矩阵进行卷积,将每一次的输出求和,对应于第一个公式。然后再聚合,将聚合输出和第i个时间步的轨迹数据带入到GRU,输出的hidden作为下一个时间步的输入,输出的预测值带入MLP输出层运算后再加上原始的inputs就作为输出的i+1步的预测值。然后重复T-1次,得到1-T步的预测值,跟原始数据进行比较。

7 Compositional Imitation Learning and Execution

第五篇,组合模仿学习和执行,学习任务是发现序列数据中的组合结构。什么是序列数据中的组合结构?文章中给出了例子

image-20211022195556795

一个找东西的游戏,在一个10*10的网格中,一次只能上下左右移动一次,然后要找到指定的物品,一个总的任务可以分为多个子任务,例如这里分为三个子任务,每次找到一个,每个子任务又是由很多时间步骤组成。

这些所有时间步组合在一起就是原始的输入数据,模型的作用就是在这些时间步中找到每个子任务的边界(划一划)。模型将输入数据进行软分割到片段,再将每一个片段都映射成一个潜编码,然后再重建。这些片段就是要找的子任务。

模型的整体结构如图:

image-20211022201712241

模型也是基于编码器和解码器结构,首先用一个初始化全为1掩膜mask和原始数据带入一个LSTM编码器对数据进行嵌入,再用数据嵌入带入一个MLP多层感知机计算第一个子任务的大致边界,将嵌入数据用计算出的边界截断后计算任务的潜编码,这是模型的一个特点,推断某一个子任务时,只使用任务对应的一部分数据,然后用潜编码对子任务进行重建,重建的子任务就可以用于与原始数据进行对比计算损失函数,再用当前边界计算用于下一阶段的掩膜,也就是前一步使用的数据在后面的步骤中会被模型忽略,用余下的数据作为输入。重复这些步骤直到所有子任务边界被找到。

CompILE流程图

子任务的具体数量在这个模型中是作为一个已知值,而且这个模型只适用于子任务是同一类型的数据。

我不太明白这个模型跟图的关系是什么?我感觉好像没有直接关系。

image-20211023121046395

每一个数字相当于一个子任务。

8 Contrastive Learning of Structured World Models

最后一章是 通过对比学习构建结构化世界模型。还是通过数据来了解模型的具体任务。

image-20211029172807858 image-20211029172601434

在这样一个图像里,有五个图形对象,图形是5×5的网格组成,每个对象占一格,每个网格由10×10的像素构成,有RGB三个通道,所以就是一张50×50×3的图像。五个对象的位置关系被称为一个状态,state,然后有一个动作,action,作用在这个状态上,这个action是一个5×4的,也就是20维的one-hot向量,表示五个对象中的一个,在上下左右四个方向的其中一个上走一格,得到下一个状态。也就是右边这幅图,经过action后,一个蓝色三角往右移了一步。所以任务就是,以一个状态和一个动作作为输入,来预测下一个状态。但是关键是还有一个限制,如果下一步前进的方向上的位置已经被占了,或者要超出边界了,那呆在原地不动。比如说:左边图像中的绿色正方形要往左,但是他往左就超出边界了,那么它就不动;紫色圆要往上,但是上面被红色圆占了,那么他也不动。这些就需要用几个对象之间的位置关系去判断。

image-20211030153746819

再看一下具体的模型,首先把原始的图像数据带入一个CNN进行对象提取,就是把每个对象所在位置提取出来,然后将其带入一个MLP编码器,得到对象的隐表示,然后带入一个GNN,GNN用的是对象的全连接图,具体公式如下图:也是典型的消息传递,先计算每条边的隐表示,再聚合,不过聚合过程中还加上了,action,也就是动作。计算出的结果就是目前的状态到下个状态的增量,用它加上目前的状态值与下一个状态值对比计算损失函数,进行训练。GNN就是将对象之间的关系信息和动作信息结合在一起,从而判断下一步的状态。

他还可以训练一个解码器,将 z t z_t zt转换回图像,从而进行基于像素的对比。

image-20211030164750785

模型实现

论文中的模型基于Pytorch的实现可以看我的Github(也是我在Github上找的)

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值