GRAPHAF:A FLOW-BASED AUTOREGRESSIVE MODEL FOR MOLECULAR GRAPH GENERATION
一、概要
设计具有所需特性的新型分子结构是诸如药物发现和材料科学等各种应用中的基本问题。.由于化学空间本质上是离散的,整个搜索空间是巨大的,大概有
1
0
33
10^{33}
1033这么大,所以这个问题是非常具有挑战性的。由于这些领域中的大量数据,机器学习技术在分子设计中看到了巨大的机遇。近年来,人们越来越致力于开发能够自动生成化学有效分子结构并优化其性质的机器学习算法。
对于自回归模型作者也有论述:自回归模型的迭代性质允许在生成过程中有效利用化学规则进行化合价检查,因此这些模型生成的有效分子的比例非常高。 但是,由于顺序生成的性质,训练过程通常很慢。 GraphAF方法具有像自动回归模型(从潜在空间到观察空间的映射)这样的迭代生成过程的优势,同时可以计算与前馈神经网络(从观察空间到潜在空间的映射)相对应的精确似然性。 通过并行计算有效地进行。
二、预备知识
2.1 Autoregressive flow
\;\;\;\;\;\;
N
o
r
m
a
l
i
z
i
n
g
f
l
o
w
\mathbf{Normalizing \;\;flow}
Normalizingflow为从基本分布
E
\mathcal E
E(可以是高斯分布)映射到真实世界可以观测到的空间
Z
\mathcal Z
Z(比如图片或者语音)的一个参数化的可逆的确定性变换。即:
ϵ
∼
p
E
(
ϵ
)
\epsilon\sim p_{\mathcal E}(\epsilon)
ϵ∼pE(ϵ)是基础分布,而
f
:
E
→
Z
f:\mathcal E\to\mathcal Z
f:E→Z是一个可逆变换。则真实世界中的分布
p
Z
(
z
)
p_{\mathbf Z}(z)
pZ(z)可由如下公式计算:
规范化流(
N
o
r
m
a
l
i
z
i
n
g
f
l
o
w
\mathbf{Normalizing \;\;flow}
Normalizingflow)作为生成模型的两个关键过程:
- p Z ( z ) p_{Z}(z) pZ(z)可以通过 f f f的逆变换来计算: ϵ = f θ − 1 ( z ) \epsilon =f_{\theta}^{-1}(\mathcal z) ϵ=fθ−1(z);
- 为了获得 z z z,计算方式如下:首先, f θ f_{\theta} fθ是一个可逆变换并且对应的雅可比行列式应该较为容易计算;接下来按照 ϵ ∼ p ( ϵ ) \epsilon\sim p(\epsilon) ϵ∼p(ϵ)进行采样,并执行前向变换 z = f θ ( ϵ ) z=f_{\theta}(\epsilon) z=fθ(ϵ),从而的到 z z z。
自回归流(AF)是满足这些条件的一种变体,其中包含三角Jacobian矩阵,并且行列式可以线性计算。给定
z
∈
R
D
z\in\mathbb R^{D}
z∈RD(
D
D
D是观测数据的维度),自回归条件概率可以参数化为高斯分布。
g
μ
g_{\mu}
gμ,
g
α
g_{\alpha}
gα是
z
1
:
d
−
1
z_{1:d-1}
z1:d−1的正标量函数,用来得到均值和偏差。在实际中,这些函数可以用神经网络来实现。由此可得:
\;\;\;\;\;\;
A
F
\mathbf{AF}
AF的雅克比矩阵是下三角的,因为
∂
z
i
∂
ϵ
j
\frac{\partial z_{i}}{\partial {\epsilon}_{j}}
∂ϵj∂zi仅仅对于
j
<
i
j<i
j<i是非零的。具体来说,要执行密度估算,我们可以并行应用所有单独的标量仿射变换来计算基本密度,每个密度都取决于先前的变量
z
1
:
d
−
1
z_{1:d-1}
z1:d−1,在刚开始我们可以先采样
ϵ
∈
R
D
\epsilon\in\mathbb R_{D}
ϵ∈RD并通过仿射参数计算
z
1
z_{1}
z1接下来再根据之前得到的
z
1
:
d
−
1
z_{1:d-1}
z1:d−1来计算
z
d
z_{d}
zd。
2.2 Graph Representation Learning
\;\;\;\;\;\;
作者将分子图表示为
G
=
(
A
,
X
)
,
A
\mathcal G=(\mathbf A,\mathbf X),\mathbf A
G=(A,X),A是伴随矩阵,
X
\mathbf X
X是节点特征矩阵。假设图中有
n
n
n个节点,
d
,
b
d,b
d,b分别代表着节点的种类的数量与边种类的数量。则有:
A
∈
{
0
,
1
}
n
×
n
×
d
,
X
∈
{
0
,
1
}
n
×
d
A\in\left\{0,1\right\}^{n\times n\times d},\;\;\;\mathbf X\in\left\{0,1\right\}^{n\times d}\;
A∈{0,1}n×n×d,X∈{0,1}n×d,并且当第
i
i
i个节点和第
j
j
j个节点之间以第
k
k
k种类型的边相连,则有
A
i
j
k
=
1
\mathbf A_{ijk}=1
Aijk=1。作者采用
R
−
G
C
N
\mathbf{R-GCN}
R−GCN来学习节点的表征向量。
k
k
k代表着节点嵌入向量的维度,我们定义第
l
l
l层
R
−
G
C
N
R-GCN
R−GCN的节点的嵌入矩阵为
H
l
∈
R
n
×
k
H^{l}\in\mathbb R^{n\times k}
Hl∈Rn×k:
符号含义:
- E i = A [ : , : , i ] \mathbf E_{i}=\mathbf A_{[:,:,i]} Ei=A[:,:,i];
- E ~ i = E i + I \mathbf{\tilde E_{i}}=\mathbf E_{i}+\mathbf I E~i=Ei+I;
- D ~ i = Σ k E ~ i [ j , k ] \mathbf{\tilde D_{i}}={\Sigma}_{k}\tilde E_{i}[j,k] D~i=ΣkE~i[j,k];
- W i ( l ) \mathbf W_{i}^{(l)} Wi(l):是第 l l l层 R − G C N \mathbf{R-GCN} R−GCN对于第 i i i种边类型的可以训练的权重矩阵。
- A g g ( ⋅ ) Agg(\cdot) Agg(⋅):表示聚合函数,例如求平均值或者相加求和。
- H 0 \mathbf H^{0} H0由节点特征矩阵 X \mathbf X X初始化,经过 L \mathbf L L层消息传递网络,得到表示矩阵 H L \mathbf H^{L} HL来作为最终的节点表征向量。同时,整个图的表征向量可以通过聚合最终的每个节点的表征向量来得到。
三、方法介绍
3.1 GraphAF架构
- 首先,给定一张空图 G 1 \mathbf G_{1} G1;
- 每一次迭代过程中,都会根据现有的子图结构添加一个新的节点即 p ( X i ∣ G i ) p(X_{i}|G_{i}) p(Xi∣Gi);
- 然后,根据当前图结构顺序生成新节点与现有节点之间的边,即 p ( A i j ∣ G i , X i , A i , 1 : j − 1 ) p(\mathbf A_{ij}|\mathbf G_{i},\mathbf X_{i},\mathbf A_{i,1:j-1}) p(Aij∣Gi,Xi,Ai,1:j−1);
- 重复上述过程,直到所有的节点和边都生成完毕,如图1(a)所示。
GraphAF旨定义一个从基本分布(比如多元高斯分布)到分子图结构 G = ( A , X ) \mathbf G=(\mathbf A,\mathbf X) G=(A,X)的可逆变换。注意,对于两个节点之间的边的类型,我们还要再加一项:“no edge”,即两个节点之间没有相连(连个原子之间没有形成化学键),则有: A ∈ { 0 , 1 } n × n × ( b + 1 ) \mathbf A\in\left\{0,1\right\}^{n\times n\times(b+1)} A∈{0,1}n×n×(b+1)。由于节点类型 X i \mathbf X_{i} Xi与边类型 A i j \mathbf A_{ij} Aij是离散的,这并不适合flow-based的模型。一种标准的解决方法是利用Dequantization技术,就是在离散数据上添加真是的噪声,对于图 G = ( A , X ) G=(A,X) G=(A,X)经过处理得到连续的表示 z = ( z A , z X ) z=(z^{A},z^{X}) z=(zA,zX)有:
对于图生成的条件分布为:
g μ X \mathbf g_{{}_{\mu}\!X} gμX, g μ A \mathbf g_{{}_{\mu}\!A} gμA, g α X \mathbf g_{{}_{\alpha}\!X} gαX, g α A \mathbf g_{{}_{\alpha}\!A} gαA是用来生成高斯分布均值以及偏差的被参数化的神经网络。给定子图结构 G i G_{i} Gi,我们用一个 L L L层的 R e l a t i o n a l − G C N \mathbf {Relational-GCN} Relational−GCN来学习到节点嵌入向量 H i L ∈ R n × k \mathbf H_{i}^{L}\in\mathbb R^{n\times k} HiL∈Rn×k,以及整个子图的表征嵌入 h ~ i ∈ R k \tilde h_{i}\in\mathbb R^{k} h~i∈Rk,并在此基础上定义高斯分布的均值和标准差,分别生成节点和边:
m μ X m_{{}_{\mu}\!X} mμX, m α X m_{{}_{\alpha}\!X} mαX,分别是用来根据现有的子图的embedding预测节点种类的 M L P \mathbf {MLP} MLP。 m μ A m_{{}_{\mu}\!A} mμA, m α A m_{{}_{\alpha}\!A} mαA,是用来根据现有的子图的embedding以及节点的embeddings来预测边的类型。为了生成一个新的节点以及它与现存节点相连的边,我们依据基本的高斯分布得到随机变量 ϵ i {\epsilon}_i ϵi, ϵ i j {\epsilon}_{ij} ϵij,即:
由此得到节点的类型和边的类型: - X i = v a r g m a x ( z i X ) d X_{i}=v^{d}_{argmax(z_{i}^{X})} Xi=vargmax(ziX)d;
- A i j = v a r g m a x ( z i j A ) b + 1 A_{ij}=v^{b+1}_{argmax(z_{ij}^{A})} Aij=vargmax(zijA)b+1
- 注: v q p v_{q}^{p} vqp表示一个 p p p维的 o n e − h o t one-hot one−hot向量,它的第 q q q个位置值为1;
ϵ
=
ϵ
1
,
ϵ
2
,
ϵ
21
,
ϵ
3
,
ϵ
31
,
.
.
.
,
ϵ
n
,
ϵ
n
1
,
.
.
.
,
ϵ
n
,
n
−
1
,
\epsilon={{\epsilon}_{1},{\epsilon}_{2},{\epsilon}_{21},{\epsilon}_{3},{\epsilon}_{31},...,{\epsilon}_{n},{\epsilon}_{n1},...,{\epsilon}_{n,n-1},}
ϵ=ϵ1,ϵ2,ϵ21,ϵ3,ϵ31,...,ϵn,ϵn1,...,ϵn,n−1,。GraphAF定义了一个基本的高斯分布
ϵ
{\epsilon}
ϵ与分子图结构
z
=
(
z
A
,
z
X
)
z=(z^{A},z^{X})
z=(zA,zX)。则反变换的可以由如下式子计算:
3.2 Efficient Parallel Training
在GraphAF中,因为变换
f
:
E
→
Z
f:\mathcal E\to\mathcal Z
f:E→Z是自回归的,则
f
−
1
:
Z
→
E
f^{-1}:\mathcal Z\to\mathcal E
f−1:Z→E的雅克比矩阵是下三角矩阵,并且它的行列式可以非常高效地计算。
在训练过程中,我们可以利用masking在输入分子图G和输出潜变量
ϵ
\epsilon
ϵ之间定义一个前馈神经网络来进行并行计算。当我们推断节点
i
i
i时,mask掉图的一些边来保证RGCN只处理子图
G
i
\mathbf G_i
Gi;当我们推断边的存在时,我们mask掉一些链接来保证RGCN只处理到
G
i
,
X
i
,
A
i
,
1
:
j
−
1
G_i,X_i,A_{i,1:j-1}
Gi,Xi,Ai,1:j−1。用这种masking的技巧,使得模型不仅满足了自回归的特性,而且能够通过并行计算所有条件,可以在一次向前传递中有效地计算
p
(
G
)
p(G)
p(G)。并且为了继续加速训练的过程,训练图的节点和边会根据BFS来被重新排序。由于BFS的性质,键只能存在于相同或连续的BFS深度内的节点之间。因此,节点间的最大依赖距离由单个BFS深度中最大的节点数所限定。在我们的数据集中,任何单个的BFS深度包含的节点不超过12个,这意味着我们只需要对当前原子和最近生成的12个原子之间的边缘进行建模。
3.3 Validity Constrained Sampling
化学种存在着很多化学规则,这能够帮助生成有效分子。由于我们的方法采取了顺序生成的过程,所以可以在每个生成步骤中利用这些规则。我们可以在采样过程中直接利用化合价约束来检查现有的化学键的数量是否满足这个约束。我的理解就是比如一个原子按照化学规则它最多能有四个与其他原子相连的化学键,但现在却有五个,那这个就不符合化学规则,需要重新采样,重新生成。
并且当以下两种条件满足时,图的生成过程就停止了:
- 图的大小达到最大值n;
- 新生成的原子和之前的子图之间没有键。最后,氢被加到没有填满价键的原子上。
3.4 Goal-Directed Molecule Generation with Reinforcement Learning
到目前为止,我们已经介绍了如何使用GraphAF对分子图结构的数据密度进行建模并生成有效分子。 但是,对于药物发现,我们还需要优化生成分子的化学性质。 在这一部分中,我们将介绍如何通过强化学习来微调我们的生成过程,以优化生成的分子的特性。
主要介绍一下Reward design,主要有阶段性的奖励以及最终性的奖励。阶段性的奖励为如果边缘预测违反了价检查,将引入一个小惩罚作为中间奖励。最终的奖励包括生成分子的目标特性的分数,如辛醇-水分配系数或者类药性,化学有效性奖励,例如对空间应变过大的分子和/或违反ZINC官能团过滤器的官能团的处罚。实际中,我们使用PPO来训练GraphAF
四、实验
4.1 实验设置
一共有三个任务:
- D e n s i t y M o d e l i n g a n d G e n e r a t i o n \mathbf{Density\;Modeling \;and \;Generation} DensityModelingandGeneration:评估模型的能力,以学习数据分布和产生现实和多样化的分子。Validity是生成的所有图中有效分子的百分比;Uniqueness是生成的分子中唯一分子的百分比。Novelty是生成的分子中没有出现在训练集中的百分比,Reconstruction是可以从潜在的载体中重建的分子的百分比。我们从10000个随机生成的分子中计算出上述指标。
- P r o p e r t y O p t i m i z a t i o n \mathbf{Property\;Optimization} PropertyOptimization:专注于生成具有优化化学性能的新型分子。 对于此任务,我们对从密度建模任务预训练的网络进行微调,以最大化所需的属性。在这项任务中,我们的目标是生成具有理想性质的分子。具体来说,我们选择惩罚logP和QED作为我们的目标属性。前者的分数是logP分数,受环大小和合成可及性的影响,而后者则测量分子的药物相似性。请注意,这两个分数都是使用经验预测模型计算的,作者采用了这篇文章中使用的脚本,以使结果具有可比性。为了完成这项任务,作者对GraphAF网络进行300个epoch的预训练,以进行似然建模,然后应用第3.4节中描述的 R L RL RL过程来微调网络,使其达到所需的化学特性。如下所示,GraphAF在惩罚logP得分方面的表现优于所有基线,并在QED方面取得了可比的结果。这一现象表明,结合RL过程,GraphAF成功地捕捉到了所需分子的分布。
- C o n s t r a i n e d P r o p e r t y O p t i m i z a t i o n \mathbf{Constrained\;Property\;Optimization} ConstrainedPropertyOptimization:其目的是在满足相似性约束的前提下,对给定的分子进行修饰以改善期望的性质。约束条件是原始分子和修改后的分子之间的相似性高于一个阈值。与性能优化任务类似,我们通过密度模型对GraphAF进行预处理,然后用RL对模型进行微调。在生成过程中,我们将初始状态设置为从800个分子中随机抽取的子图进行优化。为了进行评估,我们在表6中报告了最高改进的平均值和标准差以及原始分子和改性分子之间的相应相似性。实验结果表明,GraphAF明显优于以往的所有方法,并且几乎总能成功地改善目标特性。图2(c)展示了两个优化实例,表明我们的模型能够大幅度提高惩罚logP得分,同时保持原分子和改性分子之间的高度相似性
数据集: Z I N C 250 k {ZINC250k} ZINC250k,该数据集包含25万个类药物分子,最大原子数为38个。它具有9种原子类型和3种边类型。 我们使用开源化学软件RDkit对分子进行预处理,所有分子均以烷基化形式存在,并且去除了氢。
4.2 Numerical Results
作者通过利用广泛使用的指标来评估所提出的方法对真实分子建模的能力:对于提出的四种度量方式,GraphAF的结果如下:
并且作者为了验证GraphAF并非在
Z
I
N
C
250
k
\mathbf{ZINC250k}
ZINC250k上过拟合。作者在另外两个数据集
Q
M
9
\mathbf{QM9}
QM9,
M
O
S
E
S
\mathbf{MOSES}
MOSES上进行了实验。
Q
M
9
\mathbf{QM9}
QM9含有134k分子和9个重原子,而MOSES要大得多,也更具挑战性,它包含190万个分子和30个重原子。表3显示,即使在更复杂的数据集上,GraphAF也总是能够生成有效且新颖的分子。此外,尽管GraphAF最初是为分子图生成而设计的,但实际上它非常通用,可以通过简单地修改节点和边缘生成函数Edge-MLP和Node-MLP来用于对不同类型的图进行建模。按照GNF实验的设定,作者在两个通用的图形数据集利用GraphAF进行了实验。
C
O
M
M
U
N
I
T
Y
−
S
M
A
L
L
\mathbf{COMMUNITY-SMALL}
COMMUNITY−SMALL是包含100个2社区图的综合数据集,而
E
G
O
−
S
M
A
L
L
\mathbf{EGO-SMALL}
EGO−SMALL是从
C
i
t
e
s
e
e
r
\mathbf{Citeseer}
Citeseer数据集中提取的一组图。