1 摘要
分子构象生成旨在生成分子中所有原子的三维坐标,是生物信息学和药理学中的一项重要任务。以前的方法通常首先预测分子的原子间距离、原子间距离的梯度或局部结构(如扭转角),然后重建其三维构象。如何在没有上述中间值的情况下直接生成构象还没有得到充分的探索。在这项工作中,我们提出了一种直接预测原子坐标的方法:(1)损失函数对坐标的旋转平移和对称原子的排列是不变的;(2) 新提出的模型自适应地聚合键和原子信息,并迭代地细化生成的构象的坐标。我们的方法在GEOM-QM9和GEOM Drugs数据集上获得了最佳结果。进一步的分析表明,我们生成的构象与基态构象具有更接近的性质(例如,HOMO-LUMO间隙)。此外,我们的方法通过提供更好的初始构象来改善分子对接。所有的结果都证明了我们方法的有效性和直接方法的巨大潜力。
代码: https://github.com/DirectMolecularConfGen/DMCG
论文地址: https://arxiv.org/abs/2202.01356
2 简介
分子构象生成旨在生成分子的3D原子坐标。虽然分子构象可以通过实验获得,例如通过X射线晶体学,对于工业规模的任务来说,成本高得令人望而却步。
近年来,机器学习方法由于其准确性和效率而在构象生成方面备受关注。以前的大多数方法首先预测一些中间值,如原子间距离、原子间距离的梯度或扭转角,然后基于它们重建构象。虽然这些方法改善了分子构象的生成,但它们的缺点在于中间值应该满足额外的硬约束。不幸的是,在许多情况下并不能满足这些约束。例如,GraphDG预测了原子间距离,然后基于它们重建构象,其中实际距离(例如,考虑三个原子)应该满足三角形不等式,但作者研究表明GraphDG预测的距离在8.65%的情况下违反了不等式。另一个例子是,ConfGF预测了原子间距离的梯度,平方距离矩阵的秩最多为5。这种约束使得梯度定义不清,因为在对特定距离dij进行无穷小的改变时,其他距离不能全部保持不变。
在没有这些中间值的情况下直接生成坐标是一种更直接的策略,但尚未完全探索。AlphaFold 2是一种直接的方法,在蛋白质结构预测方面取得了显著的性能。AlphaFold 2的成功激励了作者直接生成分子构象坐标方法。
这种方法的一个挑战是保持旋转平移不变性和置换不变性。具体来说,(1)旋转和平移所有原子作为一个基团的坐标不会改变分子的构象(2) 与对称性有关的原子应该考虑置换不变性。例如,如图1所示,由于嘧啶部分沿着C-S键(原子11和12)的对称性,原子13、14和原子17、16是等价的。
问题定义
- G = ( V , E ) G=(V,E) G=(V,E) ——分子图,其中V和E分别是原子和键的集合。具体地说, V = { v 1 , v 2 , ⋯ , v ∣ V ∣ } V=\{v_{1},v_{2},\cdots,v_{|V|}\} V={v1,v2,⋯,v∣V∣}, v i v_{i} vi表示第i个原子。 e i j e_{ij} eij表示原子 v i v_{i} vi和 v j v_{j} vj之间的键。
- N ( i ) N(i) N(i) ——原子邻节点,即 N ( i ) = { j ∣ ( i , j ) ∈ E } N(i)=\{j\mid(i,j)\in E\} N(i)={j∣(i,j)∈E}。
- R R R —— G G G的构象,其中 R ∈ R ∣ V ∣ × 3 R\in\mathbb{R}^{|V|{\times}3} R∈R∣V∣×3。 R R R的第i行 R i R_i Ri是原子 v i v_{i} vi的坐标。
我们的任务是学习一个映射,给定图 G = ( V , E ) G=(V,E) G=(V,E),它可以输出 V V V中所有原子的坐标 R R R,即分子构象。
3 模型框架
3.1 损失函数
构象的旋转平移和置换不变性
设
R
∈
R
∣
V
∣
×
3
R\in\mathbb{R}^{|V|\times3}
R∈R∣V∣×3和
R
^
∈
R
∣
V
∣
×
3
\hat{R}\in\mathbb{R}^{|V|\times3}
R^∈R∣V∣×3表示基本构象和生成的构象。旋转平移和置换不变损失定义如下:
ℓ
R
T
P
(
R
,
R
^
)
=
min
ρ
;
σ
∈
S
∥
R
−
ρ
(
σ
(
R
^
)
)
∥
F
2
(
1
)
\ell_{\mathrm{RTP}}(R,\hat{R})=\min\limits_{\rho;\sigma\in\mathcal{S}}\|R-\rho(\sigma(\hat{R}))\|_F^2\quad(1)
ℓRTP(R,R^)=ρ;σ∈Smin∥R−ρ(σ(R^))∥F2(1) 上述损失函数中,
ρ
\rho
ρ表示旋转-平移操作。
S
\mathcal{S}
S表示对称原子上的置换运算的集合。例如,在图1中,S包含两个元素
σ
1
\sigma_1
σ1和
σ
2
\sigma_{2}
σ2,其中
σ
1
\sigma_{1}
σ1是一个相同的映射,即对于任何
i
∈
{
1
,
2
,
⋯
,
18
}
i\in\{1,2,\cdots,18\}
i∈{1,2,⋯,18},
σ
1
(
i
)
=
i
\sigma_{1}(i)=i
σ1(i)=i。
σ
2
\sigma_{2}
σ2是嘧啶对称原子上的映射,即
σ
2
(
13
)
=
17
\sigma_{2}(13)=17
σ2(13)=17,
σ
2
(
17
)
=
13
\sigma_{2}(17)=13
σ2(17)=13;
σ
2
(
14
)
=
16
\sigma_{2}(14)=16
σ2(14)=16,
σ
2
(
16
)
=
14
\sigma_{2}(16)=14
σ2(16)=14,对于剩余的原子i则
σ
2
(
i
)
=
i
\sigma_{2}(i)=i
σ2(i)=i。
总之,公式(1)将
R
R
R和
R
^
\hat{R}
R^的损失定义为:在对称原子的任何旋转平移操作和任何置换操作下目标构象的距离和。其可以通过四元组和图同构来求解。
实际上,为了求解方程。公式(1)被优化分解为两个子问题:
ℓ
R
T
=
min
ρ
∥
ρ
(
R
^
)
−
R
∥
F
2
(
S
1
)
\ell_{\mathrm{RT}}=\min_{\rho}\|\rho(\hat{R})-R\|_F^2\quad(S1)
ℓRT=ρmin∥ρ(R^)−R∥F2(S1)
ℓ
P
=
min
σ
∈
S
∥
σ
(
R
^
)
−
R
∥
F
2
(
S
2
)
\ell_{\mathrm{P}}=\operatorname{min}_{\sigma\in{\mathcal{S}}}\|\sigma({\hat{R}})-R\|_{F}^{2}\quad(S2)
ℓP=minσ∈S∥σ(R^)−R∥F2(S2)
S1使用四元组求解。四元组是复数的扩展。对于四元数,任何旋转运算都是由3×3的旋转矩阵矩阵指定的。S1的解是通过对
R
R
R和
R
^
\hat{R}
R^的代数运算获得的4×4矩阵的最小特征值。为了稳定训练,在代码中通过ρ停止梯度反向传播。
S2使用图同构求解。我们需要找到S中的所有元素,然后枚举它们以获得最小值,将寻找S转化为分子图上的图同构问题。
构象的一对多映射
一个分子可能对应于多种构象。因此,我们在模型中引入了一个随机变量z,用于不同构象的生成。分子图G相同,不同的z可能导致不同的构象,表示为
R
^
(
z
,
G
)
\hat{R}(z,G)
R^(z,G)。
受变分自动编码器(VAE)的启发,作者引入了一个(条件)推理模型
q
(
z
∣
R
,
G
)
q(z|R,G)
q(z∣R,G)来描述z的后验分布,以改进重建损失,即
E
q
(
z
∣
R
,
G
)
[
ℓ
R
T
P
(
R
,
R
^
(
z
,
G
)
)
]
\mathbb{E}_{q(z|R,G)}\left[\ell_{\mathrm{RTP}}(R,\hat{R}(z,G))\right]
Eq(z∣R,G)[ℓRTP(R,R^(z,G))]。同时计算KL散度作为正则化项,即
D
K
L
(
q
(
z
∣
R
,
G
)
∣
p
(
z
)
)
D_{\mathrm{KL}}(q(z|R,G)|p(z))
DKL(q(z∣R,G)∣p(z))。因此,聚集的(即平均/边缘化的)后验
∫
q
(
z
∣
R
,
G
)
p
data
(
R
)
d
R
\int q(z|R,G)p_{\text{data}}(R)\mathrm{d}R
∫q(z∣R,G)pdata(R)dR向前验
p
(
z
)
p(z)
p(z)趋近,这反过来允许具有
p
(
z
)
p(z)
p(z)样本的解码器从
p
data
(
R
)
p_{\text{data}}(R)
pdata(R)生成新的构象。
通过适当选择
q
(
z
∣
R
,
G
)
q(z|R,G)
q(z∣R,G),我们可以对损耗进行优化。我们将
q
(
z
∣
R
,
G
)
q(z|R,G)
q(z∣R,G)定义为
N
(
z
∣
μ
R
,
G
,
Σ
R
,
G
)
\mathcal{N}(z|\mu_{R,G},\Sigma_{R,G})
N(z∣μR,G,ΣR,G),具有对角协方差矩阵的多元高斯分布。它通过重新参数化实现了可处理的损失优化:
z
∼
q
(
z
∣
R
,
G
)
z \sim q(z|R,G)
z∼q(z∣R,G),相当于
z
(
i
)
=
μ
R
,
G
(
i
)
+
Σ
R
,
G
(
i
,
i
)
ϵ
,
∀
i
,
w
h
e
r
e
z^{(i)}=\mu_{R,G}^{(i)}+\sqrt{\Sigma_{R,G}^{(i,i)}}\epsilon,\forall i,\mathrm{where}
z(i)=μR,G(i)+ΣR,G(i,i)ϵ,∀i,where
ϵ
∼
N
(
0
,
1
)
\epsilon\sim\mathcal{N}(0,1)
ϵ∼N(0,1),
z
(
i
)
z^{(i)}
z(i)和
μ
R
,
G
(
i
)
\mu_{R,G}^{(i)}
μR,G(i)是
z
z
z和
μ
R
,
G
\mu_{R,G}
μR,G的第i个元素,
Σ
R
,
G
(
i
,
i
)
\Sigma_{R,G}^{(i,i)}
ΣR,G(i,i)是第i个对角线元素。KL损失被转化为
D
K
L
(
N
(
μ
R
,
G
,
Σ
R
,
G
)
∣
∥
N
(
0
,
I
)
)
D_{\mathrm{KL}}(\mathcal{N}(\mu_{R,G},\Sigma_{R,G})|\|\mathcal{N}(0,\boldsymbol{I}))
DKL(N(μR,G,ΣR,G)∣∥N(0,I)),它具有闭合形式的解。
最终,loss定义如下:
min
E
ϵ
∼
N
(
0
,
I
)
ℓ
R
T
P
(
R
,
R
^
(
z
,
G
)
)
+
β
D
K
L
(
N
(
μ
R
,
G
,
Σ
R
,
G
)
∣
∣
N
(
0
,
I
)
)
(
2
)
\min\operatorname*{E}_{\epsilon\sim\mathcal{N}(0,I)}\ell_{\mathrm{RTP}}(R,\hat{R}(z,G))+\beta D_{\mathrm{KL}}(\mathcal{N}(\mu_{R,G},\Sigma_{R,G})||\mathcal{N}(0,I))\quad(2)
minϵ∼N(0,I)EℓRTP(R,R^(z,G))+βDKL(N(μR,G,ΣR,G)∣∣N(0,I))(2) 其中
β
>
0
\beta>0
β>0是一个超参数。
3.2 训练过程
图2是训练流程和采样流程的流程图。
- 编码器
φ
2
D
\varphi_{2\mathrm{D}}
φ2D
将分子图G作为其输入,并输出几个表示,d是表示的维度。
H V ( 0 ) ∈ R ∣ V ∣ × d H_V^{(0)}\in\mathbb{R}^{|V|\times d} HV(0)∈R∣V∣×d:所有原子; H E ( 0 ) ∈ R ∣ E ∣ × d H_E^{(0)}\in\mathbb{R}^{|E|\times d} HE(0)∈R∣E∣×d:所有键
U ( 0 ) ∈ R d U^{(0)}\in\mathbb{R}^d U(0)∈Rd:全局图特征 R ^ ( 0 ) ∈ R ∣ V ∣ × 3 \hat{R}^{(0)}\in\mathbb{R}^{|V|\times3} R^(0)∈R∣V∣×3:初始构象
形式上: φ 2D ( G ) = ( H V ( 0 ) , H E ( 0 ) , U ( 0 ) , R ^ ( 0 ) ) \varphi_{\text{2D}}(G)=(H_V^{(0)},H_E^{(0)},U^{(0)},\hat{R}^{(0)}) φ2D(G)=(HV(0),HE(0),U(0),R^(0)) - 编码器
φ
3
D
\varphi_{3\mathrm{D}}
φ3D
提取构象R的特征,用于构建条件推理模块 q ( z ∣ R , G ) q(z|R,G) q(z∣R,G)。 φ 3 D \varphi_{3\mathrm{D}} φ3D输出高斯均值和协方差。
形式上: φ 3 D ( R , G ) = ( μ R , G , Σ R , G ) \varphi_{\mathrm{3D}}(R,G)=(\mu_{R,G},\Sigma_{R,G}) φ3D(R,G)=(μR,G,ΣR,G) - 前向过程
前向过程如图2(a)所示。我们采样一个随机变量z, z ∼ N ( μ R , G , Σ R , G ) z \sim \mathcal{N}(\mu_{R,G},\Sigma_{R,G}) z∼N(μR,G,ΣR,G)。然后将 H V ( 0 ) H_V^{(0)} HV(0)、 H E ( 0 ) H_E^{(0)} HE(0)、 U ( 0 ) U^{(0)} U(0)、 R ^ ( 0 ) \hat{R}^{(0)} R^(0)和z输入解码器 φ 3 D \varphi_{3\mathrm{D}} φ3D中得到分子构象 R ^ ( z , G ) \hat{R}(z,G) R^(z,G)。也就是说: φ d e c ( φ 2 d ( G ) , z ) = φ d e c ( H V ( 0 ) , H E ( 0 ) , U ( 0 ) , R ^ ( 0 ) , z ) \varphi_{\mathrm{dec}}(\varphi_{\mathrm{2d}}(G),z)=\varphi_{\mathrm{dec}}(H_{V}^{(0)},H_{E}^{(0)},U^{(0)},\hat{R}^{(0)},z) φdec(φ2d(G),z)=φdec(HV(0),HE(0),U(0),R^(0),z) 其中采样z可以通过先采样 ϵ ∼ N ( 0 , I ) \epsilon\sim\mathcal{N}(0,I) ϵ∼N(0,I),然后使 z ( i ) = μ R , G ( i ) + Σ R , G ( i , i ) ϵ z^{(i)}=\mu_{R,G}^{(i)}+\sqrt{\Sigma_{R,G}^{(i,i)}}\epsilon z(i)=μR,G(i)+ΣR,G(i,i)ϵ。
获得 R ^ ( z , G ) \hat{R}(z,G) R^(z,G)和 N ( μ R , G , Σ R , G ) \mathcal{N}(\mu_{R,G},\Sigma_{R,G}) N(μR,G,ΣR,G)后,我们通过优化损失函数进行训练。 - 推理过程
推理过程如图2(b)所示。我们使用已经训练好了的编码器 φ 2 D \varphi_{2\mathrm{D}} φ2D和解码器 φ d e c \varphi_{\mathrm{dec}} φdec。给定一个分子图 G G G,使用编码器 φ 2 D \varphi_{2\mathrm{D}} φ2D得到 H V ( 0 ) H_V^{(0)} HV(0)、 H E ( 0 ) H_E^{(0)} HE(0)、 U ( 0 ) U^{(0)} U(0)、 R ^ ( 0 ) \hat{R}^{(0)} R^(0)。接着,从 N ( 0 , I ) \mathcal{N}(0,I) N(0,I)中采样z,最后将 H V ( 0 ) H_V^{(0)} HV(0)、 H E ( 0 ) H_E^{(0)} HE(0)、 U ( 0 ) U^{(0)} U(0)、 R ^ ( 0 ) \hat{R}^{(0)} R^(0)和z输入解码器 φ d e c \varphi_{\mathrm{dec}} φdec中得到分子构象 R ^ ( z , G ) \hat{R}(z,G) R^(z,G)。
注意推理过程中没有用到编码器 φ 3 D \varphi_{3\mathrm{D}} φ3D。
3.3 模型架构
编码器
φ
2
D
\varphi_{2\mathrm{D}}
φ2D,
φ
3
D
\varphi_{3\mathrm{D}}
φ3D和解码器
φ
d
e
c
\varphi_{\mathrm{dec}}
φdec共享相同的架构。它们都堆叠了L个相同的块。我们以解码器
φ
d
e
c
\varphi_{\mathrm{dec}}
φdec为例,介绍其第l个块。
图3展示了第l个块的架构。该架构接受来自上一层的输出
R
^
(
l
−
1
)
,
H
V
(
l
−
1
)
,
H
E
(
l
−
1
)
,
U
(
l
−
1
)
\hat{R}^{(l-1)},H_{V}^{(l-1)},H_{E}^{(l-1)},U^{(l-1)}
R^(l−1),HV(l−1),HE(l−1),U(l−1),同时输出新的
R
^
(
l
)
,
H
V
(
l
)
,
H
E
(
l
)
,
U
(
l
)
\hat{R}^{(l)},H_{V}^{(l)},H_{E}^{(l)},U^{(l)}
R^(l),HV(l),HE(l),U(l),一直重复直到输出最终的构象
R
^
(
L
)
\hat{R}^{(L)}
R^(L)。其第一层的输入
R
^
(
0
)
,
H
V
(
0
)
,
H
E
(
0
)
,
U
(
0
)
\hat{R}^{(0)},H_{V}^{(0)},H_{E}^{(0)},U^{(0)}
R^(0),HV(0),HE(0),U(0)来源于编码器
φ
2
D
\varphi_{2\mathrm{D}}
φ2D的输出。
作者使用GN块的变体作为模型的主干。在每个区块中,我们首先更新键表示,然后更新原子表示,最后更新全局分子表示和构象。
从数学上讲,设
h
i
(
l
)
h_{i}^{(l)}
hi(l)表示第l个块输出的原子i的表示,
h
i
j
(
l
)
h_{ij}^{(l)}
hij(l)表示原子i和j之间的键的表示,MLP表示前馈网络。则第l个块实行以下操作:
- 更新键表示
h ˉ i ( l ) = h i ( l − 1 ) + M L P ( R ^ i ( l − 1 ) ) + z , ∀ i ∈ V ( 3 ) h ˉ i j ( l ) = h i j ( l − 1 ) + M L P ( ∥ R ^ i ( l − 1 ) − R ^ j ( l − 1 ) ∥ ) , ∀ ( i , j ) ∈ E \begin{aligned} &\bar{h}_i^{(l)} =h_i^{(l-1)} + \mathtt{MLP}(\hat{R}_i^{(l-1)}) + z, \forall i\in V& (3) \\ &\bar{h}_{ij}^{(l)} =h_{ij}^{(l-1)}+\mathtt{MLP}(\|\hat{R}_i^{(l-1)}-\hat{R}_j^{(l-1)}\|),\forall(i,j)\in E \end{aligned} hˉi(l)=hi(l−1)+MLP(R^i(l−1))+z,∀i∈Vhˉij(l)=hij(l−1)+MLP(∥R^i(l−1)−R^j(l−1)∥),∀(i,j)∈E(3) 其中, z ∼ N ( μ R , G , Σ R , G ) z \sim \mathcal{N}(\mu_{R,G},\Sigma_{R,G}) z∼N(μR,G,ΣR,G)。通过上面两个式子,将原子坐标信息融合到键的表示中。
随后,键表示更新如下:
h i j ( l ) = h i j ( l − 1 ) + M L P ( h ˉ i ( l − 1 ) , h ˉ j ( l − 1 ) , h ˉ i j ( l − 1 ) , U ( l − 1 ) ) ( 4 ) h_{ij}^{(l)}=h_{ij}^{(l-1)}+\mathtt{MLP}(\bar{h}_{i}^{(l-1)},\bar{h}_{j}^{(l-1)},\bar{h}_{ij}^{(l-1)},U^{(l-1)})\quad(4) hij(l)=hij(l−1)+MLP(hˉi(l−1),hˉj(l−1),hˉij(l−1),U(l−1))(4) - 更新原子表示
h ~ i ( l ) = ∑ j ∈ W ( i ) α j W v c o n c a t ( h ˉ i j ( i ) , h ˉ j ( i − 1 ) ) w h e r e α j ∝ exp ( a ⊤ ζ ( W q h ˉ i ( l − 1 ) + W k c o n c a t ( h ˉ j ( l − 1 ) , h ˉ i j l ) ) ) ( 5 ) h i ( l ) = h i ( l − 1 ) + M L P ( h ˉ i ( l − 1 ) , h ~ i ( l ) , U ( l − 1 ) ) . \begin{aligned} &\tilde{h}_{i}^{(l)} =\sum_{j\in W(i)}\alpha_{j}W_{v}\mathrm{concat}(\bar{h}_{i j}^{(i)},\bar{h}_{j}^{(i-1)})~\mathrm{where}~\alpha_{j}\propto\exp(\mathbf{a}^{\top}\zeta(W_{q}\bar{h}_{i}^{(l-1)}+W_{k}\mathrm{concat}(\bar{h}_{j}^{(l-1)},\bar{h}_{i j}^{l})))& \left(5\right) \\ &h_i^{(l)} =h_i^{(l-1)}+\mathtt{MLP}\Big(\bar{h}_i^{(l-1)},\tilde{h}_i^{(l)},U^{(l-1)}\Big). \end{aligned} h~i(l)=j∈W(i)∑αjWvconcat(hˉij(i),hˉj(i−1)) where αj∝exp(a⊤ζ(Wqhˉi(l−1)+Wkconcat(hˉj(l−1),hˉijl)))hi(l)=hi(l−1)+MLP(hˉi(l−1),h~i(l),U(l−1)).(5)
其中, a , W q , W v \boldsymbol{a},W_q,W_v a,Wq,Wv和 W k W_k Wk是需要学习的参数,concat()是两个向量的级联, ζ \zeta ζ是ReLU激活。对于原子 v i v_{i} vi,我们首先使用GATv2来聚合其连接键的表示,以获得 h ~ i \tilde{h}_{i} h~i,随后基于 h ~ i ( l ) , h ˉ i ( l − 1 ) , U ( l − 1 ) \tilde{h}_{i}^{(l)},\bar{h}_{i}^{(l-1)},U^{(l-1)} h~i(l),hˉi(l−1),U(l−1)来更新 v i v_{i} vi。 - 更新全局表示
U ( l ) = U ( l − 1 ) + M L P ( 1 ∣ V ∣ ∑ i = 1 ∣ V ∣ h i ( l ) , 1 ∣ E ∣ ∑ i , j h i j ( l ) , U ( l − 1 ) ) ( 6 ) U^{(l)}=U^{(l-1)}+\mathtt{MLP}\Big(\frac{1}{|V|}\sum_{i=1}^{|V|}h_i^{(l)},\frac{1}{|E|}\sum_{i,j}h_{ij}^{(l)},U^{(l-1)}\Big)\quad(6) U(l)=U(l−1)+MLP(∣V∣1i=1∑∣V∣hi(l),∣E∣1i,j∑hij(l),U(l−1))(6) - 更新分子构象
R ˉ i ( l ) = M L P ( h i ( l ) ) , m ( l ) = 1 ∣ V ∣ ∑ j = 1 ∣ V ∣ R ˉ j ( l ) , R ^ i ( l ) = R ˉ i ( l ) − m ( l ) + R ^ i ( l − 1 ) ( 7 ) \begin{aligned} & \\&\bar{R}_{i}^{(l)}=\mathtt{MLP}(h_{i}^{(l)}),\quad m^{(l)}& =\frac{1}{|V|}\sum_{j=1}^{|V|}\bar{R}_{j}^{(l)},\quad\hat{R}_{i}^{(l)}=\bar{R}_{i}^{(l)}-m^{(l)}+\hat{R}_{i}^{(l-1)}&& (7) \end{aligned} Rˉi(l)=MLP(hi(l)),m(l)=∣V∣1j=1∑∣V∣Rˉj(l),R^i(l)=Rˉi(l)−m(l)+R^i(l−1)(7)
上述方程中的一个重要步骤是在进行初始预测 R ˉ i ( l ) \bar{R}_i^{(l)} Rˉi(l)后,我们需要计算其中心,并通过将中心移动到原点来归一化它们的坐标。这种规范化可确保每个块生成的坐标在合理的数值范围内。
我们使用解码器 φ d e c \varphi_{\mathrm{dec}} φdec最后一个区块输出的 R ^ ( L ) \hat{R}^{(L)} R^(L)作为构象的最终预测。
4 实验
4.1 模型结构
所有的编码器和解码器都有6个块,特征的维度为256。受Transformer中前馈层的启发,MLP由两个子层组成,其中第一个子层将输入特征从256维映射到隐藏状态,然后是批量归一化和ReLU激活。然后,使用线性映射将隐藏状态再次映射到256。
考虑到我们的方法在每个块l处输出构象
R
^
(
l
)
\hat{R}^{(l)}
R^(l),我们希望每个
R
^
(
l
)
\hat{R}^{(l)}
R^(l)都与实际的构象相似。因此损失函数(2)实现为:
ℓ
R
T
P
(
R
^
(
L
)
,
R
)
+
λ
∑
l
=
0
L
−
1
ℓ
R
T
P
(
R
^
(
l
)
,
R
)
(
8
)
\ell_{\mathrm{RTP}}(\hat{R}^{(L)},R)+\lambda\sum_{l=0}^{L-1}\ell_{\mathrm{RTP}}(\hat{R}^{(l)},R)\quad(8)
ℓRTP(R^(L),R)+λl=0∑L−1ℓRTP(R^(l),R)(8)
4.1 评价
假设在测试集中,分子x具有
N
x
N_{x}
Nx构象。对于测试集中的每个分子x,我们生成2
N
x
N_{x}
Nx个构象。设
S
g
\mathbb{S}_g
Sg和
S
r
\mathbb{S}_r
Sr分别表示所有生成构象和基本构象。
我们使用覆盖率得分(COV)和匹配得分(MAT)来评估生成质量。为了测量
R
R
R和
R
^
\hat{R}
R^之间的差异,在RDKit包中使用GetBestRMS,并将均方根偏差表示为
RMSD
(
R
,
R
^
)
\text{RMSD}(R,\hat{R})
RMSD(R,R^)。基于召回的覆盖率和匹配分数定义如下:
C
O
V
(
S
g
,
S
r
)
=
1
∣
S
r
∣
∣
{
R
∈
S
r
∣
R
M
S
D
(
R
,
R
^
)
<
δ
,
∃
R
^
∈
S
g
}
∣
(
9
)
M
A
T
(
S
g
,
S
r
)
=
1
∣
S
r
∣
∑
R
∈
S
r
min
R
^
∈
S
g
R
M
S
D
(
R
,
R
^
)
\begin{gathered} \mathrm{COV}(\mathbb{S}_{g},\mathbb{S}_{r})=\frac{1}{|\mathbb{S}_{r}|}\left|\left\{R\in\mathbb{S}_{r}\mid\mathrm{RMSD}(R,\hat{R})<\delta,\exists\hat{R}\in\mathbb{S}_{g}\right\}\right| & (9)\\ \\\mathrm{MAT}\left(\mathbb{S}_{g},\mathbb{S}_{r}\right)=\frac{1}{|\mathbb{S}_{r}|}\sum_{R\in\mathbb{S}_{r}}\operatorname*{min}_{\hat{R}\in\mathbb{S}_{g}}\mathrm{RMSD}(R,\hat{R}) \end{gathered}
COV(Sg,Sr)=∣Sr∣1
{R∈Sr∣RMSD(R,R^)<δ,∃R^∈Sg}
MAT(Sg,Sr)=∣Sr∣1R∈Sr∑R^∈SgminRMSD(R,R^)(9)
一个好的结果应该有一个高的COV分数和一个低的MAT分数。
5 代码
5.1 前向过程
首先,我们对节点和边进行编码,得到节点特征,边特征和全局特征。
onehot_x = one_hot_atoms(x) # 进行onehot编码
onehot_edge_attr = one_hot_bonds(edge_attr)
graph_idx = torch.arange(num_graphs).to(x.device) # 创建batch中图的索引[1,2,...,num_graphs]
edge_batch = torch.repeat_interleave(graph_idx, num_edges, dim=0) # 创建一个大小为(num_edges,)的向量,其中的每个元素都是相应边所在的图的索引。
x_embed = self.encoder_node(onehot_x) # 编码节点
edge_attr_embed = self.encoder_edge(onehot_edge_attr) # 编码边
u_embed = self.global_init.expand(num_graphs, -1) # 将self.global_init中的参数在第0维上进行复制,生成一个大小为(num_graphs, latent_size)的u_embed。
# 包含每个图的全局特征u_embed
extra_output = {}
上述三个特征需要经过编码器 φ 2 D \varphi_{2\mathrm{D}} φ2D。注意编码器 φ 2 D \varphi_{2\mathrm{D}} φ2D的输入其实有四个,还有一个是初始构象。对于2D编码器,输入的初始构象为[-1,1]的噪声。特征在经过不断的消息传递之后,得到最终的四个参数 H V ( 0 ) , H E ( 0 ) , U ( 0 ) , R ^ ( 0 ) H_V^{(0)},H_E^{(0)},U^{(0)},\hat{R}^{(0)} HV(0),HE(0),U(0),R^(0)。
# prior conf
cur_pos = x_embed.new_zeros((x_embed.size(0), 3)).uniform_(-1, 1) # 创建一个与 x_embed形状相同的tensor,用于保存当前节点的坐标。uniform_()将其填充为在[-1, 1]之间的随机坐标。
pos_list = []
x = x_embed # 接受了参数,为特定值
edge_attr = edge_attr_embed # 接受了参数,为特定值
u = u_embed # 初始化为0
for i, layer in enumerate(self.prior_conf_gnns): # 消息传递
extended_x, extended_edge_attr = self.extend_x_edge(cur_pos, x, edge_attr, edge_index) # 嵌入边特征和节点特征 cur_pos->x
x_1, edge_attr_1, u_1 = layer( # 经过消息传递层
extended_x,
edge_index,
extended_edge_attr,
u,
edge_batch,
node_batch,
num_nodes,
num_edges,
)
x = F.dropout(x_1, p=self.dropout, training=self.training) + x # 加上消息传递并dropout后的值
edge_attr = F.dropout(edge_attr_1, p=self.dropout, training=self.training) + edge_attr
u = F.dropout(u_1, p=self.dropout, training=self.training) + u
if self.pred_pos_residual:
delta_pos = self.prior_conf_pos[i](x)
cur_pos = self.move2origin(cur_pos + delta_pos, batch)
else:
cur_pos = self.prior_conf_pos[i](x)
cur_pos = self.move2origin(cur_pos, batch)
cur_pos = self.random_augmentation(cur_pos, batch)
pos_list.append(cur_pos)
# extra_output["prior_last_edge"] = self.dist_pred_layer(edge_attr)
extra_output["prior_pos_list"] = pos_list
prior_output = [x, edge_attr, u]
在训练过程中,需要加入编码器 φ 3 D \varphi_{3\mathrm{D}} φ3D。其输入的节点特征,边特征和全局特征与编码器 φ 2 D \varphi_{2\mathrm{D}} φ2D相同。但是编码器 φ 3 D \varphi_{3\mathrm{D}} φ3D的构象输入为分子的真实3D构象,将它们经过编码后得到 μ R , G , Σ R , G \mu_{R,G},\Sigma_{R,G} μR,G,ΣR,G,创建噪声z。
if not sample:
# encoder
x = x_embed
edge_attr = edge_attr_embed
u = u_embed
cur_pos = self.move2origin(batch.pos, batch)
if not self.no_3drot:
cur_pos = get_random_rotation_3d(cur_pos)
for i, layer in enumerate(self.encoder_gnns): # 编码器编码
extended_x, extended_edge_attr = self.extend_x_edge(
cur_pos, x, edge_attr, edge_index
)
x_1, edge_attr_1, u_1 = layer(
extended_x,
edge_index,
extended_edge_attr,
u,
edge_batch,
node_batch,
num_nodes,
num_edges,
)
x = F.dropout(x_1, p=self.dropout, training=self.training) + x
edge_attr = (
F.dropout(edge_attr_1, p=self.dropout, training=self.training) + edge_attr
)
u = F.dropout(u_1, p=self.dropout, training=self.training) + u
if self.use_ss:
cur_pos = get_random_rotation_3d(cur_pos)
if self.use_global:
aggregated_feat = u
else:
aggregated_feat = self.pooling(x, node_batch)
if self.use_ss:
extra_output["query_feat"] = self.prediction_head(
self.projection_head(aggregated_feat)
)
with torch.no_grad():
x = x_embed
edge_attr = edge_attr_embed
u = u_embed
for i, layer in enumerate(self.encoder_gnns):
extended_x, extended_edge_attr = self.extend_x_edge(
cur_pos, x, edge_attr, edge_index
)
x_1, edge_attr_1, u_1 = layer(
extended_x,
edge_index,
extended_edge_attr,
u,
edge_batch,
node_batch,
num_nodes,
num_edges,
)
x = F.dropout(x_1, p=self.dropout, training=self.training) + x
edge_attr = (
F.dropout(edge_attr_1, p=self.dropout, training=self.training)
+ edge_attr
)
u = F.dropout(u_1, p=self.dropout, training=self.training) + u
cur_pos = get_random_rotation_3d(cur_pos)
if self.use_global:
aggregated_feat_1 = u
else:
aggregated_feat_1 = self.pooling(x, node_batch)
extra_output["key_feat"] = self.projection_head(aggregated_feat_1)
latent = self.encoder_head(aggregated_feat)
latent_mean, latent_logstd = torch.chunk(latent, chunks=2, dim=-1)
extra_output["latent_mean"] = latent_mean
extra_output["latent_logstd"] = latent_logstd
z = self.reparameterization(latent_mean, latent_logstd)
else:
z = torch.randn_like(u_embed) * self.sample_beta
z = torch.repeat_interleave(z, num_nodes, dim=0)
if self.reuse_prior:
x, edge_attr, u = prior_output
else:
x, edge_attr, u = x_embed, edge_attr_embed, u_embed
最后,我们将得到的噪声z加入到节点特征中,输入解码器,得到分子构象。噪声z通过与编码过的节点特征x相加进行结合。
for i, layer in enumerate(self.decoder_gnns):
if i == len(self.decoder_gnns) - 1:
cycle = self.cycle
else:
cycle = 1
for _ in range(cycle):
extended_x, extended_edge_attr = self.extend_x_edge(
cur_pos, x + z, edge_attr, edge_index
)
x_1, edge_attr_1, u_1 = layer(
extended_x,
edge_index,
extended_edge_attr,
u,
edge_batch,
node_batch,
num_nodes,
num_edges,
)
x = F.dropout(x_1, p=self.dropout, training=self.training) + x
edge_attr = (
F.dropout(edge_attr_1, p=self.dropout, training=self.training) + edge_attr
)
u = F.dropout(u_1, p=self.dropout, training=self.training) + u
if self.pred_pos_residual:
delta_pos = self.decoder_pos[i](x)
cur_pos = self.move2origin(cur_pos + delta_pos, batch)
else:
cur_pos = self.decoder_pos[i](x)
cur_pos = self.move2origin(cur_pos, batch)
cur_pos = self.random_augmentation(cur_pos, batch)
pos_list.append(cur_pos)
if self.sg_pos:
cur_pos = cur_pos.detach()
# extra_output["decoder_last_edge"] = self.dist_pred_layer(edge_attr)
return pos_list, extra_output
5.2 损失函数
def compute_loss(self, pos_list, extra_output, batch, args):
loss_dict = {}
loss = 0
pos = batch.pos
new_idx = GNN.update_iso(pos, pos_list[-1], batch) # 更新节点索引以进行同构性对齐操作
loss_tmp, _ = self.alignment_loss( # 计算当前位置 pos 与先前位置extra_output["prior_pos_list"][-1]之间的对齐损失
pos, extra_output["prior_pos_list"][-1].index_select(0, new_idx), batch
)
loss = loss + loss_tmp
loss_dict["loss_prior_pos"] = loss_tmp # 旋转平移loss
mean = extra_output["latent_mean"]
log_std = extra_output["latent_logstd"]
kld = -0.5 * torch.sum(1 + 2 * log_std - mean.pow(2) - torch.exp(2 * log_std), dim=-1) # 算 KL 散度损失
kld = kld.mean()
loss = loss + kld * args.vae_beta # 乘以β
loss_dict["loss_kld"] = kld
loss_tmp, _ = self.alignment_loss(
pos, pos_list[-1].index_select(0, new_idx), batch, clamp=args.clamp_dist
)
loss = loss + loss_tmp
loss_dict["loss_pos_last"] = loss_tmp # 算当前位置 pos 与最后一次更新后的位置 pos_list[-1] 之间的对齐损失
if args.aux_loss > 0:
for i in range(len(pos_list) - 1):
loss_tmp, _ = self.alignment_loss(
pos, pos_list[i].index_select(0, new_idx), batch, clamp=args.clamp_dist
)
loss = loss + loss_tmp * (args.aux_loss if i < len(pos_list) - args.cycle else 1.0)
loss_dict[f"loss_pos_{i}"] = loss_tmp # 对于每个先前位置 pos_list[i](不包括最后一个位置),计算当前位置 pos 与先前位置之间的对齐损失
if args.ang_lam > 0 or args.bond_lam > 0:
bond_loss, angle_loss = self.aux_loss(pos, pos_list[-1].index_select(0, new_idx), batch)
loss_dict["bond_loss"] = bond_loss
loss_dict["angle_loss"] = angle_loss
loss = loss + args.bond_lam * bond_loss + args.ang_lam * angle_loss # 键合损失 bond_loss 和角度损失 angle_loss
if self.use_ss:
anchor = extra_output["query_feat"]
positive = extra_output["key_feat"]
anchor = anchor / torch.norm(anchor, dim=-1, keepdim=True)
positive = positive / torch.norm(positive, dim=-1, keepdim=True)
loss_tmp = torch.einsum("nc,nc->n", [anchor, positive]).mean()
loss = loss - loss_tmp
loss_dict[f"loss_ss"] = loss_tmp # 计算自监督损失。计算锚点特征 anchor 和正样本特征 positive 的内积,取均值。
loss_dict["loss"] = loss
return loss, loss_dict
6 思考与改进
加入diffusion的思想,在模型架构之中嵌入diffusion。