之前分享的EVE利用MSA来训练VAE模型,Tranception在UniRef100上训练自回归模型,二者都是无监督学习,即不依赖任何标签,通过对蛋白序列的挖掘来学习突变效应。ProteinNPT于2023年发表于NeurIPS Proceedings,能够将无监督预测作为辅助标签,和真实标签一起用于训练网络,做半监督学习。截止到2024.10.08,ProteinNPT于是ProteinGym DMS Substitutions数据上有监督学习的最佳模型。
数据
在ProteinGym上进行交叉检验,设计了三种五折交叉检验的划分方式:
- Random:每个突变随机分配,可能会导致同一位置的突变被分配到不同fold
- Contiguous:仅考虑包含突变的位置,将序列分割为连续的片段,确保每个片段中包含同样多的突变位置
- Modulo:仅考虑包含突变的位置,对于 5 折交叉验证,位置 1 被分配给fold 1,位置 2 分配给fold 2,…,位置 6 分配给fold 1,以此类推。
仅在评估单突变时才使用Contiguous和Modulo交叉验证方案。多突变仅基于Random交叉验证方案。
模型
模型采用MSA Transformer预训练的embedding(MSA Transformer内部冻结不参与训练),线性投影到 d d d=200。序列的真实标签和辅助标签(zero-shot preditions from MSA Transformer)经过标准缩放后也线性投影到 d d d=200,和序列embedding拼接起来。对batch B中的每条序列,获得$(L_{seq}+2) \cdot d $形式的input embed,经过CNN非线性化后输入5个连续的ProteinNPT layers。
每个 ProteinNPT 层依次row-attention、column-attention和前馈层。每个变换前都有 LayerNorm,每个变换后都有残差连接。
最后,通过输入L2-penalized linear projector来进行预测。目标函数包括去噪和预测两部分:
L total = α t ⋅ L AA reconstruction + ( 1 − α t ) ⋅ L target prediction \mathcal{L}^{\text {total }}=\alpha_t \cdot \mathcal{L}^{\text {AA reconstruction }}+\left(1-\alpha_t\right) \cdot \mathcal{L}^{\text {target prediction }} Ltotal =αt⋅LAA reconstruction +(1−αt)⋅Ltarget prediction
ProteinNPT layers代码如下:
from ..utils.esm.modules import AxialTransformerLayer
self.layers = nn.ModuleList(
[
AxialTransformerLayer(
self.args.embed_dim,
self.args.ffn_embed_dim,
self.args.attention_heads,
self.args.dropout,
self.args.attention_dropout,
self.args.activation_dropout,
getattr(self.args, "max_tokens_per_msa", self.args.max_tokens_per_msa),
self.deactivate_col_attention,
self.tranception_attention,
self.num_targets_input,
)
for _ in range(self.args.num_protein_npt_layers)
]
)
训练与推理
训练中,batch size取425,PLM embedding提前存储进硬盘,每次随机屏蔽氨基酸和标签各64。使用AdamW优化器,测试各种early stopping策略后选择使用固定10k的training steps。
推理时,将训练集与测试集row-wise拼接,类似Chain of Thought,使得模型进行Colunmn attention时了解序列-标签关系和有标注序列之间的同源性。拼接对训练集设定最大数量M,如果训练数据集较大,随机不替换采样M条训练数据(随机替换采样经测试效果相近)。
消融实验
消融实验全部在8个较小的DMS assay上进行(见B.2),统一使用Spearman’s rank correlation作为比较指标。
- 使用哪个PLM:比较了不使用PLM、ESM-1v、Tranception和MSA Transformer,使用PLMembedding的同时选择对应的zero-shot预测做辅助标签。在三种交叉检验方式上MSA Transformer表现均为最佳。
- 是否使用辅助标签:使用MSA Transformer预测后,在三种交叉检验方式上模型性能都获得提高。
- post embedding非线性层架构:在随机交叉检验上比较了Light attention、Linear、MLP、ConvBERT、CNN后选择CNN
- 推理时M取值:相比于不添加训练数据,添加100条即在三种交叉检验方式上均有有性能提升,添加1000以上时不再有性能提升,故M取1k。
- 训练集大小:Small size (≤2k labels),Medium size (2-8k labels) and Large size (>8k labels),未观察到small assays上性能显著下降(medium表现最好)。
性能
ProteinGym
比较了ProteinNPT和以下模型:
- Zero-shot (MSA Transformer):无修改,5个模型取ensemble
- OHE: ridge regression on trained one-hot-encodings
- OHE - Augmented (DeepSequence): OHE + zero-shot predictions from DeepSequence
- OHE - Augmented (MSA Transformer): OHE + zero-shot predictions from the MSA Transformer
- Embeddings - Augmented (MSA Transformer):该团队增强后的mean pooled embeddings from the MSA Transformer + zero-shot predictions from the MSA Transformer,见C.1。
在单突变、多突变、多分类三个不同任务上,ProteinNPT都取得了最优表现。
Protein Redesign
模拟重设计中,从小部分有标签数据开始训练模型,对所有(总数据中存在的)未标记数据根据Upper Confidence Bound (UCB) 打分,获取最佳的少量数据标签加入训练集,重新训练-评分-获取标签。UCB函数如下:
a ( x ; λ ) = μ ( x ) + λ ⋅ σ ( x ) a(x ; \lambda)=\mu(x)+\lambda \cdot \sigma(x) a(x;λ)=μ(x)+λ⋅σ(x)
其中 μ \mu μ代表预测的适应性, σ \sigma σ代表预测的不确定性, λ \lambda λ做探索-利用的权衡。
采用5 Monte Carlo dropout samples for 5 resample inference batches混合的策略来衡量不确定性:
- Monte Carlo dropout:在推理时使用固定的推理批次,使用多个前向传递的预测标准差作为不确定性度量
- Batch resampling:对于待预测的数据点,用有替换的方式采样不同样本与该数据点组成不同的input batches,取前向传递后的标准差。
对 ProteinGym 扩展基准中的所有 DMS 测定进行迭代蛋白质重设计实验。仅考虑单突变,根据assay包含的数据量将assay分为三组:
- 对< 1250 的assay,从 50 个有标签数据开始,并在每次迭代中获取 N = 50 个新标签。
- 对>1250 且< 2500 的assay,从 100 个有标签数据开始,并在每次迭代中获取 N = 100 个新标签。
- 对>2500 的assay,从50 个有标签数据开始,并在每次迭代中获取N = 200 个新标签。
运行三次以保证稳定。ProteinGym几乎在所有assay中超过基线,且在数据量更大的assay上表现更好。
Conditional Sampling
以GFP DMS assay为例,首先识别高适应性序列(top 3 deciles),然后构建输入:随机选择其他有标签序列+对最优序列做5个掩码,输入模型后获取掩码位置的log softmax分布,采样、获取新序列。
注意,掩码位置不是随机选择,而是采样在ProteinNPT最后一层与标签的row-wise注意力分数最高的位置。
使用ESM-1v评价条件采样结果。