ProteinMPNN的神经网络模型,主要用于处理蛋白质相关的数据。模型包括特征提取部分(ProteinFeatures)、编码器层(EncLayer)和译码器层(DecLayer)
ProteinMPNN forward函数的部分代码:
# Concatenate sequence embeddings for autoregressive decoder
h_S = self.W_s(S)
h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
# Build encoder embeddings
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
# ...
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
for layer in self.decoder_layers:
h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
h_V = torch.utils.checkpoint.checkpoint(layer, h_V, h_ESV, mask)
-
代码解读:
-
h_EXV_encoder
没有序列信息:- 没错,
h_EXV_encoder
是从编码器得到的图结构的上下文信息,反映的是整个蛋白质结构的空间关系,但它不包含具体的序列信息(即氨基酸序列)。在解码过程中,这个信息可以帮助模型理解空间位置之间的相互作用。
- 没错,
-
mask_fw * h_EXV_encoder
得到当前位置以后的信息,但没有序列信息:mask_fw
用于选择前向的(未来的)位置。mask_fw * h_EXV_encoder
只包含当前位置及以后的结构信息(不含有序列信息)。这确保模型不会提前看到还未预测的序列,但可以利用结构上的上下文。
-
mask_bw * h_ESV
是当前位置及以前的信息,包括序列信息:mask_bw
作用于h_ESV
,这其中的h_ESV
包含了序列嵌入(h_S
)和邻域信息(h_E
)。mask_bw * h_ESV
会屏蔽未来位置,但保留当前和过去的序列信息。所以,这一部分反映了模型已经看到的序列内容及其相应的结构信息。
-
mask_bw * h_ESV + h_EXV_encoder_fw
屏蔽了当前位置以后的序列信息,但含有所有位置的其它信息:- 通过
mask_bw * h_ESV
,我们确保仅使用过去和当前的序列信息,而h_EXV_encoder_fw
提供了编码器输出的结构信息(但没有序列),这样就能够屏蔽未来的序列信息(保持自回归的要求),同时结合所有位置的结构信息来进行预测。
- 通过
总结来说,这种做法确保在每个解码步骤中,模型不会泄露未来的序列信息,但仍然可以利用整个蛋白质结构的空间上下文,这对蛋白质序列预测非常重要。