De Novo Prediction of RNA 3D Structures with Deep Learning
Year: 2022
Authors: Julius Ramakers, Christopher Frederik Blum, Sabrina K¨onig, Stefan Harmeling, Markus Kollmann
Journal Name: bioxiv
1 Innovation
结合自回归深度生成模型、蒙特卡罗树搜索和分数模型预测 RNA 三维折叠结构。
2 Method
2.1 Neural Network input data format
每个残基( residue )都能用 5 原子位置表示, P, C4’, C2, C6, N9 表示 G/A , P, C4’, C2, C4, N1 表示 C/U ,每个残基都能编码为
8
×
3
8 \times 3
8×3 的坐标矩阵和
8
×
1
8 \times 1
8×1 的掩码数组,如下表所示
2.2 Distance tensors( Euclidian distances )
对于长度为 L L L 的 RNA 序列,计算每个被编码原子之间欧氏距离,结果为 8 L × 8 L 8L \times 8L 8L×8L 的距离矩阵(残基中只存在 5 种原子, 3 种不存在原子之间的距离记为 0 ),之后该矩阵 reshape 为 L × L × 64 L \times L \times 64 L×L×64 的距离张量( Euclidian distances ) D D D 。
2.3 Structures with long or multiple chains
当结构具有多个链时,以与它的长度成正比的概率随机选择一个链并用作子结构,如果该链的长度超过 100 个残基,则随机截取长度为 100nt 的部分用作子结构。对该子结构中的每个残基,计算与子结构外所有残基的距离。对于存在距离低于 3.3 A ˚ 3.3 \mathring{A} 3.3A˚ 的残基被标记为 “fixed”,并且 “fixed” 残基所对应的距离类别( distance classes )在生成器训练期间在输入端呈现。
2.4 Data augmentation
作者反向使用 SimRNA ,从原始结构开始提高温度以“远离”原始结构,为每个训练结构生成 100 个新数据,它们与原始结构相距大约 1 , 3 , 5 , 10 A ˚ 1, 3, 5, 10\mathring{A} 1,3,5,10A˚ RMSE。
2.5 Distance classes
作者使用向量量化变分自动编码器( Vector Quantised Variational Autoencoder, VQ-VAE) ,将残基中原子之间的欧式距离压缩为 K = 3 K = 3 K=3 个类别,这 3 个类别与距离度量 “near”, “intermediate”, “far” 非常吻合,所以称之为距离类别( distance classes )。文章中没有说明为什么吻合。编码器输出大小为 L × L × 8 L \times L \times 8 L×L×8 ,码本中含有 3 个向量,于是得到 L × L L \times L L×L 的 index 。为了保持对称性。之后,将 index 沿着前两个维度转置后相加除以 2 ,将 index 变成 one hot 向量得到 L × L × 3 L \times L \times 3 L×L×3 大小的张量,这个张量即为深度生成模型的目标。
2.6 Generator Network: Data Preprocessing
将以下四种张量堆叠起来作为生成器的输入
- L × L × 8 L \times L \times 8 L×L×8 的编码后 RNA 序列张量,和 SPOT-RNA 的输入类似,其中未知核苷酸的 one hot 编码内容全为 0.25 。对于长度小于 100 的序列,张量统一填充 -1 。
- L × L × 3 L \times L \times 3 L×L×3 的部分掩码的 distance classes 张量,即将 VQ-VAE 编码的 distance classes 张量的部分像素置为 0 。从均值为 L 2 / 2 L^2 / 2 L2/2 ,标准差为 L 2 / 4 L^2 / 4 L2/4 ,以均值周围 2 个标准差为界的截断正态分布中选取掩码像素个数。作者发现神经网络较难区分掩码类别 (0, 0, 0) 和正常类别 (1, 0, 0), (0, 1, 0), (0, 0, 1) ,于是将正常类别中的 0 变为 -1 。掩码的目的是使模型更具有泛化性。
- L × L × 2 L \times L \times 2 L×L×2 的坐标系框架包含“对角框架”(对角线为 1 ,其余为 0 )和“填充框架”(边界为 1 ,其余为 0 )。对于长度小于 100 的序列,框架统一填充 -1 。
- L × L × 1 L \times L \times 1 L×L×1 的同源序列和 SHAPE 数据的注意力图,这部分将在下节详细说明。
2.7 Generator Network: Attention map of homologous sequences and SHAPE
对于每个训练样本,搜索 50 个 one hot 编码的同源序列,如果同源序列不足 50 ,则采取原始序列作为填充,获得 L × 50 × 4 L \times 50 \times 4 L×50×4 大小的张量,再使用全连接层将其映射至大小为 L × 50 L \times 50 L×50 的张量,再将 SHAPE( selective hydroxyl acylation analyzed by primer extension ) 反应值拼接进去,得到大小为 L × 51 L \times 51 L×51 的张量。作者采用 transformer 中的自注意力机制得到大小为 L × L × 1 L \times L \times 1 L×L×1 的注意力图, query 和 key 的大小为 L × 64 L \times 64 L×64 。引入 value 效果会不会更好?
2.8 Generator Network: Architecture
生成器主要为加残差的卷积神经网络,具体结构详见论文 Supplementary Information 中的 Generator Network: Architecture ,最后输出大小为 L × L × 3 L \times L \times 3 L×L×3 ,目的是预测 distance class map 。
2.9 Score Model
分数模型的目的是区分两个 distance class map 哪一个更好,输出
L
×
L
×
1
L \times L \times 1
L×L×1 的 logit map,模型结构详见论文 Supplementary Information 中的 Score Model 。在训练中,输入一个正确(原始)的 distance class map 和相对应不正确(通过数据增广得到)的 distance class map 。
训练目标为最大化函数
J
(
D
)
=
E
s
N
∼
P
t
r
u
e
(
s
N
,
x
)
,
s
N
′
∼
P
f
a
l
s
e
(
s
N
′
,
x
)
[
l
o
g
D
(
s
N
,
s
N
′
;
x
)
]
+
E
s
N
∼
P
t
r
u
e
(
s
N
,
x
)
,
s
N
′
∼
P
f
a
l
s
e
(
s
N
′
,
x
)
[
l
o
g
(
1
−
D
(
s
N
′
,
s
N
;
x
)
)
]
D
(
s
N
,
s
N
′
;
x
)
=
1
1
+
e
x
p
[
f
(
s
N
′
,
x
)
−
f
(
s
N
,
x
)
]
J(D) = \mathbb{E}_{s_N \sim P_{true}(s_N, x), s_N' \sim P_{false}(s_N', x)}[logD(s_N, s_N'; x)] + \\ \mathbb{E}_{s_N \sim P_{true}(s_N, x), s_N' \sim P_{false}(s_N', x)}[log(1-D(s_N', s_N; x))] \\ D(s_N, s_N'; x) = \frac{1}{1 + exp[f(s_N', x) - f(s_N, x)]}
J(D)=EsN∼Ptrue(sN,x),sN′∼Pfalse(sN′,x)[logD(sN,sN′;x)]+EsN∼Ptrue(sN,x),sN′∼Pfalse(sN′,x)[log(1−D(sN′,sN;x))]D(sN,sN′;x)=1+exp[f(sN′,x)−f(sN,x)]1
其中, f ( s N , x ) f(s_N, x) f(sN,x) 为神经网络的标量输出, s N s_N sN 为正确样本, s N ′ s_N' sN′ 为错误样本, x x x 为序列信息。
不正确的 logit map 减去正确的 logit map ,然后对所有差异求和。可不可以仿照 ARES ,将模型的输出变为与真实结构的差距?
2.10 MCTS: Sampling structural ensembles
MCTS 算法使用生成器迭代地对三个 distance classes 进行像素采样。首先,我们设置对角的 distance class 为 “near” 。接着,使用 MCTS 迭代的为剩下的像素添加 distance classes 。通常,当 MCTS 能够正确预测 30% 的 distance classes 时,生成器会产生足够清晰的预测。
剩下的像素索引记作
i
∈
{
1
,
.
.
.
,
N
}
,
N
=
L
(
L
−
1
)
/
2
i \in \{ 1, ..., N \}, N = L(L-1)/2
i∈{1,...,N},N=L(L−1)/2 ,
i
i
i 像素的 distance classes 记作
k
i
∈
{
1
,
.
.
.
,
K
}
k_i \in \{ 1, ..., K \}
ki∈{1,...,K} ,直到
t
t
t 时刻已经进行采样的像素集合称为状态
s
t
s_t
st ,
t
t
t 时刻生成器的输出为
P
(
k
i
∣
s
t
,
x
)
P(k_i | s_t, x)
P(ki∣st,x) 。根据生成器的输出做出行动
a
t
a_t
at (进行像素采样),所以
s
t
=
(
a
0
,
.
.
.
,
a
t
)
s_t = (a_0, ..., a_t)
st=(a0,...,at) 。
2.11 Structural Sampling
当模型处于叶节点,进行选择阶段时(选择从根节点
s
0
s_0
s0 到叶节点
s
L
s_L
sL 的路径),策略为
a
t
=
arg max
a
(
Q
(
s
t
,
a
∣
x
)
+
U
(
s
t
,
a
)
)
Q
(
s
t
,
a
∣
x
)
=
1
L
∑
t
=
1
L
H
(
s
0
∣
x
)
−
H
(
s
t
∣
x
)
H
(
s
0
∣
x
)
H
(
s
∣
x
)
=
−
∑
j
=
1
N
∑
k
j
=
1
K
P
(
k
j
∣
s
,
x
)
l
o
g
P
(
k
j
∣
s
,
x
)
U
(
s
t
,
a
)
=
c
∑
a
N
(
s
t
,
a
)
1
+
N
(
s
t
,
a
)
a_t = \argmax_a (Q(s_t, a | x) + U(s_t, a)) \\ Q(s_t, a | x) = \frac{1}{L} \sum_{t=1}^L \frac{H(s_0 | x) - H(s_t | x)}{H(s_0 | x)} \\ H(s | x) = - \sum_{j=1}^N \sum_{k_j = 1}^K P(k_j | s, x) logP(k_j | s, x) \\ U(s_t, a) = c \frac{\sqrt{\sum_a{N(s_t, a)}}}{1 + N(s_t, a)}
at=aargmax(Q(st,a∣x)+U(st,a))Q(st,a∣x)=L1t=1∑LH(s0∣x)H(s0∣x)−H(st∣x)H(s∣x)=−j=1∑Nkj=1∑KP(kj∣s,x)logP(kj∣s,x)U(st,a)=c1+N(st,a)∑aN(st,a)
其中
Q
Q
Q 表示熵减率,
N
(
s
t
,
a
)
N(s_t, a)
N(st,a) 表示在
s
t
s_t
st 状态执行动作
a
a
a 的次数,初始化为 1 ,
c
c
c 为超参数。
当到达叶节点后,进行扩展阶段。剩余像素子集记作
S
R
S_R
SR ,随机选取子集
S
H
⊂
S
R
S_H \subset S_R
SH⊂SR 使得熵减
Δ
H
L
<
λ
l
o
g
K
\Delta H_L < \lambda logK
ΔHL<λlogK ,本文中作者设
λ
=
1
\lambda = 1
λ=1 。如果在采样过程中熵减达不到阈值怎么办?如果在扩展阶段采用
ϵ
\epsilon
ϵ 贪婪策略会不会更有效率?
Δ
H
=
H
(
s
L
+
1
∣
x
)
−
H
(
s
L
∣
x
)
\Delta H = H(s_{L+1} | x) - H(s_L | x)
ΔH=H(sL+1∣x)−H(sL∣x)
在回溯阶段
N
(
s
t
,
a
t
)
←
N
(
s
t
,
a
t
)
+
∣
S
H
∣
Q
(
s
t
,
a
t
∣
x
)
←
Q
(
s
t
,
a
t
∣
x
)
+
∣
S
H
∣
N
(
s
t
,
a
t
)
(
Q
(
s
L
+
1
,
a
t
∣
x
)
−
Q
(
s
t
,
a
t
∣
x
)
)
N(s_t, a_t) \leftarrow N(s_t, a_t) + | S_H | \\ Q(s_t, a_t | x) \leftarrow Q(s_t, a_t | x) + \frac{|S_H|}{N(s_t, a_t)}(Q(s_{L+1}, a_t | x) - Q(s_t, a_t | x))
N(st,at)←N(st,at)+∣SH∣Q(st,at∣x)←Q(st,at∣x)+N(st,at)∣SH∣(Q(sL+1,at∣x)−Q(st,at∣x))
本文中
∣
S
H
∣
=
10
| S_H | = 10
∣SH∣=10 。
最后,通过分数模型辨别终止叶节点的好坏,使用 VAE 的解码器获得 RNA 结构,接着通过最小化粗粒度分子能量方程进行改进( Boniecki, M. et al. Simrna: a coarse-grained method for rna folding simulations and 3d structure prediction. Nucleic Acids Res 20;44(7):e63 (2016). )。