DEEP GRAPH INFOMAX(DGI)
摘要
本文提出了DGI,一种以无监督的方式学习图结构数据中节点表示的一般方法。DGI依赖于最大限度地扩大图增强表示和目前提取到的图信息之间的互信息——两者都是使用已建立的图卷积网络体系结构导出的。对于图增强表示,是根据目标节点所生成的子图,因此可以用于下游节点的表示学习任务。与大多数以前使用GCNs进行无监督学习的方法相比,DGI不依赖于随机游走目标,并且很容易适用于直推式学习和归纳式学习。我们在各种节点分类基准上展示了竞争性能,有时甚至超过了监督学习的性能。
1 Introduction
将神经网络推广到图形结构输入是当前机器学习的主要挑战之一。虽然最近取得了重大进展,特别是图卷积网络,但大多数成功的方法使用监督学习,这往往是不可能的,因为大多数图表数据在野外是未标记的。 此外,从大规模图中发现新颖或有趣的结构往往是可取的,因此,无监督的图学习对于许多重要的任务是必不可少的。
目前,具有图结构数据的无监督表示学习的主要算法依赖于基于随机游走的目标,有时进一步简化以重建邻接信息。基本的直觉是训练编码器网络,使输入图中的“接近”节点在表示空间中也“接近”。
虽然功能强大,而且与传统的衡量标准(如个性化PageRank评分)相关,但随机游走方法受到已知的限制。最突出的是,随机游走目标以牺牲结构信息为代价过分强调邻近信息,并且性能高度依赖于超参数的选择。此外,随着基于图卷积的更强的编码器模型的引入,目前还不清楚随机游走目标是否真的提供了任何有用的信号,因为这些编码器已经强制产生了一种归纳式偏差,即相邻节点具有类似的表示。
在这项工作中,我们提出了一种用于无监督图学习的替代目标,这种目标是基于互信息,而不是随机游走。在概率论和信息论中,两个随机变量的互信息(Mutual Information,简称MI)是指变量间相互依赖性的量度。近年来基于互信息的代表性工作是 MINE,其中提出了一种 Deep InfoMax (DMI) 方法来学习高维数据的表示。具体来说 DMI 训练一个编码模型来最大化高阶全局表示和输入的局部部分的互信息(如果从 cv 的角度理解就是一张图片中的 patches)。这鼓励编码器携带出现在所有位置的信息类型(因此是全局相关的),例如类标签的情况。
2 相关工作
2.1 对比方法
对于无监督学习一类重要的方法就是对比学习,通过训练编码器使它在特征表示中更具判别性来捕获感兴趣的和不感兴趣的统计依赖性。例如,对比方法可以使用评分函数,训练编码器来增加“真实”输入的分数,并减少“假”输入的分数,以此判别真实数据和假数据。有很多方法可以对一个表示进行打分,但在图形文献中,最常见的技术是使用分类,尽管也会使用其他的打分函数。DGI在这方面也是对比性的,因为DGI目标是基于对局部-全局对和负抽样配对的分类。
2.2 抽样战略
对比方法的一个关键实现细节是如何绘制正负样本。关于无监督图表示学习的先前工作依赖于局部对比损失(强制近端节点具有相似的嵌入)。从语言建模的角度来看,正样本通常对应于在图中短时间的随机游走中一起出现的节点对,有效地将节点视为单词,将随机游走视为句子。最近有的方法提出使用节点锚定采样作为替代。这些方法的负采样主要是基于随机对的抽样。
2.3 预测编码
对比预测编码(CPC)是另一种基于互信息最大化的深度表示的学习方法。CPC也是一种对比学习方法,它使用条件密度的估计(以噪声对比估计的形式)作为评分函数。然而,与DGI 不同的是,CPC是预测性的:对比目标有效地训练了输入的结构指定部分(例如,相邻节点对之间或节点与其邻居之间)之间的预测器。DGI 不同之处在于同时对比一个图的全局/局部部分,其中全局变量是从所有的局部变量计算出来的。
3 DGI Methodology
在本节中,我们将以自上而下的方式介绍DGI方法:首先是对我们特定的无监督学习设置的抽象概述,然后是对我们的方法优化的目标函数的阐述,最后是在单图设置中枚举我们过程的所有步骤。
3.1 基于图的无监督学习
我们假设一个通用的基于图的无监督机器学习设置:
首先给出一组节点特征,
X
=
{
x
1
⃗
,
x
2
⃗
,
.
.
.
,
x
N
⃗
}
X=\{ \vec{x_1}, \vec{x_2},..., \vec{x_N}\}
X={x1,x2,...,xN},其中
N
N
N是图中的节点数,
x
i
⃗
∈
R
F
\vec{x_i}∈\mathbb{R}^{F}
xi∈RF代表节点
i
i
i的特征表示。邻接矩阵
A
∈
R
N
×
N
A∈\mathbb{R}^{N×N}
A∈RN×N,在本文中默认所有处理的图是无权图,同时邻接矩阵存储的值为 0 或 1。
模型的目的是学习一个编码器, ε ε ε: R N × F × R N × N → R N × F ′ \mathbb{R}^{N×F}×\mathbb{R}^{N×N}→\mathbb{R}^{N×F'} RN×F×RN×N→RN×F′,可以形式化的表示为 E ( X , A ) = H = { h 1 ⃗ , h 2 ⃗ , . . . , h N ⃗ } \mathcal{E}(\pmb{X},\pmb{A})=\pmb{H}=\{ \vec{h_1}, \vec{h_2},..., \vec{h_N}\} E(XXX,AAA)=HHH={h1,h2,...,hN},其中 H H H代表高阶表示,并且每个节点 i i i满足 h i ⃗ ∈ R F ′ \vec{h_i}∈\mathbb{R}^{F'} hi∈RF′。所得到的节点特征的高阶表示可以用于各种下游任务,例如节点分类任务。
在这里,我们将重点讨论图卷积编码器,它通过不断聚合目标节点周边的邻居来完成特征学习。它所产生的 h i ⃗ \vec{h_i} hi总结了以节点 i i i为中心的图的一个patch,而不仅仅是节点本身 。在接下来的内容中,我们通常将 h i ⃗ \vec{h_i} hi称为patch representations来强调这一点。
3.2 局部-全局互信息最大化
DGI 的核心思想在于通过最大化局部互信息来训练编码器——也就是说,DGI寻求获得节点(即局部)表示,以捕获整个图的全局信息(表示为summary vector, s ⃗ \vec{s} s)。
为了得到图级别的summary vector s ⃗ \vec{s} s,作者提出了一种 readout 函数, R \mathcal{R} R: R N × F → R F \mathbb{R}^{N×F}→\mathbb{R}^{F} RN×F→RF,利用它将获得的patch representations总结为图级别的表示。上述过程可以总结为 s ⃗ = R ( E ( X , A ) ) \vec{s}=\mathcal{R}(\mathcal{E}(\pmb{X},\pmb{A})) s=R(E(XXX,AAA))。
作为最大化局部互信息的指标,我们使用了一个discriminator, D \mathcal{D} D: R F × R F → R \mathbb{R}^{F}×\mathbb{R}^{F}→\mathbb{R} RF×RF→R。这样, D ( h i ⃗ , s ⃗ ) \mathcal{D}(\vec{h_i},\vec{s}) D(hi,s)表示分配给这个patch-summary对的概率分数(对于包含在summary中的patch应该更高)。
D \mathcal{D} D的负样本由 ( X , A ) (\pmb{X},\pmb{A}) (XXX,AAA)的summary vector s ⃗ \vec{s} s与一个可选择的图 ( X ~ , A ~ ) (\widetilde{\pmb{X}},\widetilde{\pmb{A}}) (XXX ,AAA )的patch representations h ~ j ⃗ \vec{\widetilde{h}_j} h j提供。在多图的数据集中, ( X ~ , A ~ ) (\widetilde{\pmb{X}},\widetilde{\pmb{A}}) (XXX ,AAA )可以通过训练集的其他元素获得。但是,对于单个图,需要一个显式(随机)corruption function, C \mathcal{C} C: R N × F × R N × N → R M × F × R M × M \mathbb{R}^{N×F}×\mathbb{R}^{N×N}→\mathbb{R}^{M×F}×\mathbb{R}^{M×M} RN×F×RN×N→RM×F×RM×M来生成负样本的图 ( X ~ , A ~ ) (\widetilde{\pmb{X}},\widetilde{\pmb{A}}) (XXX ,AAA )。上述过程可以表述为 ( X ~ , A ~ ) = C ( X ~ , A ~ ) (\widetilde{\pmb{X}},\widetilde{\pmb{A}})=\mathcal{C}(\widetilde{\pmb{X}},\widetilde{\pmb{A}}) (XXX ,AAA )=C(XXX ,AAA )。
负样本抽样程序的选择将决定着作为这种最大化的副产品所希望捕获的具体结构信息的种类。
对于目标,我们遵循Deep InfoMax,使用带有标准二值交叉熵(BCE)损失的噪声对比型目标函数(正样本和负样本之间):
通过上式可以有效地最大化
h
i
⃗
\vec{h_i}
hi和
s
⃗
\vec{s}
s之间的互信息。
由于所有导出的patch representations都是为了保存与全局图总结的互信息,这允许在patch级别上发现和保存相似性——例如,具有相似结构的远距离节点(众所周知,这对于许多节点分类任务来说是一个强大的预测因素)。
3.3 理论动力
3.4 DGI概述
假设单图设置(即
(
X
,
A
)
(\pmb{X},\pmb{A})
(XXX,AAA)作为输入),DGI 的步骤:
- 通过corruption function得到负样本实例: ( X ~ , A ~ ) ∽ C ( X , A ) (\widetilde{\pmb{X}},\widetilde{\pmb{A}})\backsim \mathcal{C}(\pmb{X},\pmb{A}) (XXX ,AAA )∽C(XXX,AAA)。
- 通过编码器获得输入图的patch representations h i ⃗ \vec{h_i} hi: H = E ( X , A ) = { h 1 ⃗ , h 2 ⃗ , . . . , h N ⃗ } \pmb{H}=\mathcal{E}(\pmb{X},\pmb{A})=\{ \vec{h_1}, \vec{h_2},..., \vec{h_N}\} HHH=E(XXX,AAA)={h1,h2,...,hN}。
- 通过编码器获得负样本的patch representations h ~ j ⃗ \vec{\widetilde{h}_j} h j: H ~ = E ( X ~ , A ~ ) = { h ~ 1 ⃗ , h ~ 2 ⃗ , . . . , h ~ M ⃗ } \widetilde{\pmb{H}}=\mathcal{E}(\widetilde{\pmb{X}},\widetilde{\pmb{A}})=\{ \vec{\widetilde{h}_1},\vec{\widetilde{h}_2},..., \vec{\widetilde{h}_M}\} HHH =E(XXX ,AAA )={h 1,h 2,...,h M}。
- 通过 Readout 函数传递输入图的patch representations来得到图级别的summary vector: s ⃗ = R ( H ) \vec{s}=\mathcal{R}(\pmb{H}) s=R(HHH)。
- 通过梯度下降法最小化目标函数式(1),更新参数 E \mathcal{E} E、 R \mathcal{R} R、 D \mathcal{D} D。
4 实验
我们评估了DGI编码器在各种节点分类任务(直推式学习和归纳式学习)上学习的表示的好处,获得了有竞争力的结果。在每种情况下,DGI都被用来以完全无监督的方式学习patch representations,然后评估这些表示的节点级分类效用。这是通过直接使用这些表示来训练和测试一个简单的线性(逻辑回归)分类器来实现的。
4.1 数据集
(1)在Cora、Citeseer和Pubmed引文网络上对研究论文进行主题分类。
(2)以Reddit帖子为模型预测社交网络的社区结构。
(3)对蛋白质-蛋白质相互作用(PPI)网络中的蛋白质作用进行分类,需要对未见网络进行归纳。
4.2 实验设置
对于三个实验设置(直推式学习、大图上的归纳式学习和多图上的归纳式学习)中的每一个,我们使用了与该设置相适应的不同编码器和corruption function。
1. 直推式学习
编码器是一层图卷积网络(GCN)模型,具有以下传播规则:
其中,
A
^
=
A
+
I
N
\hat{A}=A+I_N
A^=A+IN代表加上自环的邻接矩阵,
D
^
\hat{D}
D^代表相应的度矩阵,满足
D
^
i
i
=
∑
j
A
^
i
j
\hat{D}_{ii}=\sum_j\hat{A}_{ij}
D^ii=∑jA^ij。对于非线性激活函数
σ
\sigma
σ,选择PReLU。
Θ
∈
R
F
×
F
′
\Theta∈R^{F×F'}
Θ∈RF×F′是应用于每个节点的可学习线性变换。
对于破坏函数 C C C,直接采用 A ~ = A \widetilde{A}=A A =A,但是 X ~ \widetilde{X} X 是由原本的特征矩阵 X X X经过随机变换得到的。也就是说,损坏的图由与原始图完全相同的节点组成,但它们位于图中的不同位置,因此将得到不同的邻近表示。
2. 大图上的归纳式学习
对于归纳学习,不再在编码器中使用GCN更新规则(因为学习的滤波器依赖于固定的和已知的邻接矩阵);相反,我们应用平均池( mean-pooling)传播规则,GraphSAGE-GCN:
D
^
−
1
\hat{D}^{-1}
D^−1实际上执行的是标准化的和(因此是 mean-pooling)。尽管式(4)明确指定了邻接矩阵和度矩阵,但并不需要它们:因为 Const-GAT 模型中使用的持续关注机制可以观察到相同的归纳行为。
对于Reddit数据库,DGI 的编码器是一个带有跳过连接的三层均值池模型:
由于数据集的规模很大,它将不能完全适合GPU内存。因此,采用子抽样方法,首先选择小批量的节点,然后,通过对具有替换的节点邻域进行抽样,得到以每个节点为中心的子图。具体来说,DGI 在第一层、第二层和第三层分别采样10、10和25个邻居,这样每个次采样的 patch 有1 + 10 + 100 + 2500 = 2611个节点。只进行了推导中心节点
i
i
i的 patch 表示
h
i
h_i
hi所必需的计算。这些表示然后被用来为minibatch(图2)导出总结向量
s
⃗
\vec{s}
s。在整个训练过程中使用了256个节点的 minibatch 。
为了在此设置中定义破坏函数,DGI 使用与在直推式学习中类似的方法,但将每个次采样的patch作为一个单独的要破坏的图(即在次采样的patch中按行随机打乱特征矩阵)。这很可能导致中心节点的特征被替换为抽样邻居的特征,进一步鼓励负样本的多样性。然后将在中心节点中获得的patch表示提交给鉴别器。
图2中,摘要向量 s ⃗ \vec{s} s是通过组合几个子采样的邻近表示 h i ⃗ \vec{h_i} hi得到的。
3. 多图上的归纳式学习
例如 PPI 数据集,编码器是一个带有密集跳过连接的三层均值池模型
其中,
W
s
k
i
p
W_{skip}
Wskip是一个可学习的投影矩阵。
在这个多图设置中,DGI 选择使用随机抽样的训练图作为负样本(即,DGI 的破坏函数只是从训练集中抽样一个不同的图)。作者发现该方法是最稳定的,因为该数据集中超过40%的节点具有全零特征。为了进一步扩大负样本池,作者还将dropout应用于采样图的输入特征。作者发现,在将学习到的嵌入信息提供给逻辑回归模型之前,将其标准化是有益的。
Readout,鉴别器的细节
在所有三个实验设置中,作者使用了相同的readout函数和discriminator体系结构。
对于 Readout Function,作者使用所有节点特征的简单平均值:
作者通过应用一个简单的双线性评分函数对图级别的summarize-patch representation对进行评分:
其中,
W
W
W是一个可学习的评分权重参数,
σ
\sigma
σ是逻辑Sigmoid非线性,用于将分数转换为
(
h
i
⃗
,
s
⃗
)
(\vec{h_i},\vec{s})
(hi,s)为正对的概率。
4.3 结果
参考博客
DEEP GRAPH INFOMAX
论文笔记:ICLR 2019 Deep Graph Infomax
Deep Graph Infomax(DGI) 论文阅读笔记
DEEP GRAPH INFOMAX 阅读笔记
【ICLR 2019论文】互信息最大化的无监督图神经网络Deep Graph Infomax