【AI模型学习】ESM2

1. 版本

ESM-2 一共有多个版本,主要区别在于:
层数(depth)参数量(size)推理速度和精度权衡
这些版本都遵循相同的 Transformer 编码器架构,只是在大小和计算能力上有差异。

版本一览

模型名称(Hugging Face 名)层数参数量说明
esm2_t6_8M_UR50D68M极小模型,适合快速原型
esm2_t12_35M_UR50D1235M中等小型,推荐用于入门任务
esm2_t30_150M_UR50D30150M中等模型,效果与效率平衡
esm2_t33_650M_UR50D33650M较大模型,适合更复杂任务
esm2_t36_3B_UR50D363B超大模型,强大建模能力
esm2_t48_15B_UR50D4815B最大模型,性能最强但最重

模型名一般遵循这个格式:

esm2_t<层数>_<参数量>_UR50D
  • t<层数>:比如 t12 表示有 12 层 Transformer 编码器。
  • <参数量>:大概的参数数量,比如 35M, 3B 等。
  • UR50D:代表训练用的数据集(Uniref50),是去冗余后的蛋白质数据库。
需求场景推荐模型
快速测试 / 教学 / CPU 调用esm2_t6_8Mt12_35M
标准下游任务建模esm2_t30_150M
结构预测 / 蛋白功能预测esm2_t33_650M 及以上
追求最强性能esm2_t48_15B

注意:3B 和 15B 模型通常需要 A100 等大显存 GPU,或者分布式推理支持。


2. 开始

2.1 安装

pip install fair-esm  # latest release, OR:
pip install git+https://github.com/facebookresearch/esm.git  # bleeding edge, current repo main branch

2.2 使用预训练模型

2.2.1 代码

我们直接拿官网的作为示例

import torch
import esm

# 加载 ESM-2 模型
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # 设置为评估模式(关闭 dropout,确保结果可复现)

# 准备数据(来自 ESMStructuralSplitDataset 数据集的前两个蛋白质序列)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# 提取每个残基的表示(使用 CPU 推理)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

# 通过对每个序列的残基取平均,生成序列级别的表示
# 注意:token 0 是序列起始符号 <cls>,第一个氨基酸是 token 1
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
    sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))

# 查看模型中无监督注意力图生成的接触预测图
import matplotlib.pyplot as plt
for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    plt.matshow(attention_contacts[: tokens_len, : tokens_len])
    plt.title(seq)
    plt.show()


2.2.2 讲解

results 的结构总览:

results.keys()
# dict_keys(['logits', 'representations', 'attentions', 'contacts'])

其中包含了 4 个主要部分,分别是:

  1. logitstorch.Size([4, 73, 33])

含义:每个位置的分类 logits,用于预测氨基酸(或 <mask> 的掩码预测)。

  • [4, 73, 33]

    • 4:表示 batch 中有 4 个序列
    • 73:是 padding 后的最大 token 长度
    • 33:是氨基酸 vocabulary 的大小(包括 <mask><pad> 等)

用途

  • 如果你用 <mask>,这个张量可以用于做“突变打分”或“掩码填空”
  • 可以通过 torch.nn.functional.softmax(logits, dim=-1) 得到每个位置的预测概率分布
  1. representationsdict,键是层号,值是 embedding

你指定了 repr_layers=[33],所以返回了:

representations = {
    33: torch.Size([4, 73, 1280])
}

含义:第 33 层(即最后一层)输出的每个 token 的 embedding 表示。

  • 维度 [4, 73, 1280]

    • 4:batch 大小
    • 73:token 序列长度(含 <cls><eos>
    • 1280:embedding 的维度(模型隐藏层大小)

用途

  • 提取每个残基的表示用于下游任务(分类、聚类、结构预测)
  • 可以对 1~L 之间的向量取平均,生成序列级表示
  1. attentionstorch.Size([4, 33, 20, 73, 73])

含义:每一层、每个头的 self-attention 权重

  • [4, 33, 20, 73, 73]

    • 4:batch 中 4 条序列
    • 33:ESM2 的 transformer 层数
    • 20:每层的 attention head 数量
    • 73 x 73:每个 head 的注意力矩阵

用途

  • 可视化每层每个头的注意力
  • 为接触图预测提供基础(即下一个)
  1. contactstorch.Size([4, 71, 71])

含义:预测的残基接触图(非监督 attention 平均生成的)

  • 维度 [4, 71, 71]:对应于每个序列的残基之间的接触概率

为什么是 71 而不是 73?
因为 73 包含了 <cls><eos>,它们会被自动排除,真正的残基只有 71 个。

用途

  • 可以直接作为结构接触预测图的初步结果
  • 在结构预测任务中可用于辅助建模残基之间的关系

总结(表格版)

类型维度含义
logitsTensor[B, L, V]每个 token 的分类 logits,用于 <mask> 推断
representationsdict{层号: Tensor[B, L, D]}每层 token 表示,通常用最后一层
attentionsTensor[B, L, H, T, T]每层每个 head 的注意力矩阵
contactsTensor[B, L', L']每个序列的残基接触预测图(不含 /)

2.2 结构预测

import torch
import esm

# 这里中间会有几个库需要pip一下,看报错信息补好即可
# 不过好像。。这里面的问题有点小多。。。然后就是如果想做这一块的话,python版本放到3.9及以下
model = esm.pretrained.esmfold_v1()
model = model.eval().cuda()

# 可选:取消注释以下语句以设置轴向注意力的块大小。这可以帮助降低显存占用。
# 块越小,显存占用越低,但计算速度可能会变慢。
# model.set_chunk_size(128)

sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
# 多聚体预测时,可以用 ':' 分隔不同链

with torch.no_grad():
    output = model.infer_pdb(sequence)

with open("result.pdb", "w") as f:
    f.write(output)

import biotite.structure.io as bsio
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
print(struct.b_factor.mean())  # 这将输出 pLDDT 分数的平均值
# 88.3

3. 任务类型总结

1. 蛋白质结构预测(ESMfold)

输入

  • 单序列:氨基酸序列字符串(如 "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG")。
  • 多聚体:用 : 分隔链的序列(如 "chainA:chainB")。
  • 批量输入:FASTA 文件(含多条序列)。

输出

  • 单序列结果:PDB 格式的蛋白质结构文件,包含原子坐标和 pLDDT 评分(预测置信度)。
    # 示例代码输出  
    output = model.infer_pdb(sequence)  # PDB 文本字符串  
    
  • 批量结果:指定目录下的多个 PDB 文件,文件名对应 FASTA 序列 ID。

2. 特征嵌入提取(esm-extract)

输入

  • 模型:预训练模型名称(如 esm2_t33_650M_UR50D)或本地模型路径。
  • 序列数据:FASTA 文件(含一条或多条序列)。
  • 参数:指定提取的层(--repr_layers)和输出类型(--include mean/per_tok/contacts)。

输出

  • ** per_tok 嵌入**:每个残基的特征向量(形状为 [seq_len, hidden_dim])。
  • ** mean 嵌入**:序列全局平均特征(形状为 [hidden_dim])。
  • 接触预测:注意力图导出的接触概率矩阵(形状为 [seq_len, seq_len])。
  • 文件格式:每个序列对应一个 .pt 文件,存储为 PyTorch 张量。

3. 零镜头变体预测(ESM-1v/ESM-2)

输入

  • 野生型序列:氨基酸序列字符串。
  • 突变位点:如 "A123G"(第123位丙氨酸突变为甘氨酸)。
  • 批量输入:CSV 或 FASTA 文件,包含多组野生型-突变序列对。

输出

  • 功能影响评分:突变对蛋白质功能的预测效应(如稳定性、活性变化)。
  • 示例输出:对数概率或相对效应值,用于排序突变的有害性。

4. 逆向折叠(ESM-IF1)

4.1 序列设计(给定结构采样)

  • 输入
    • 结构文件:PDB 或 mmCIF 文件(含主链坐标,如 5YH2.pdb)。
    • 链选择:指定目标链(如 --chain C)。
    • 参数:采样温度(--temperature,控制序列多样性)。
  • 输出:FASTA 文件,包含生成的氨基酸序列(如 sampled_sequences.fasta)。

4.2 序列评分(给定结构评估)

  • 输入
    • 结构文件:PDB/mmCIF 文件。
    • 序列文件:FASTA 文件(含待评分序列)。
  • 输出:CSV 文件,包含每条序列的平均对数似然值(如 5YH2_mutated_seqs_scores.csv)。

5. 宏基因组图谱数据(ESM Atlas)

输入

  • 查询方式
    • 序列搜索:FASTA 序列,通过 API 或 Foldseek 搜索相似结构。
    • 结构搜索:PDB/mmCIF 文件或结构 ID,检索同源结构。
  • 批量下载:通过官网提供的链接下载全量结构数据(如.tar.gz压缩包)。

输出

  • 结构数据:PDB 文件或预计算的 ESM-2 嵌入(.npy 或 .pt 文件)。
  • 搜索结果:匹配的结构列表,包含相似性分数和功能注释。

6. 多序列比对分析(ESM-MSA-1b)

输入

  • MSA 数据:A3M 格式的多序列比对文件(含同源序列)。
  • 模型输入:通过 esm.pretrained.esm_msa1b_t12_100M_UR50S() 加载模型。

输出

  • MSA 特征:从比对中提取的进化保守性嵌入,用于增强结构预测或功能分析。
  • 接触预测:结合 MSA 信息的残基接触图,精度高于单序列模型。

7. 生成式蛋白质设计(ESM-2)

输入

  • 设计约束:自然语言描述(如“设计一个具有 ATP 结合位点的螺旋结构”)或编程指令(如示例中的蛋白质编程语言)。

输出

  • 全新序列:满足特定功能或结构约束的氨基酸序列,可通过 ESMfold 验证结构合理性。

数据格式总结表

任务输入格式输出格式关键工具/模型
结构预测单/多聚体序列、FASTAPDB 文件、pLDDT 评分ESMfold、esm-fold
特征提取FASTA、模型名称.pt 张量文件esm-extract
变体预测野生型/突变序列功能影响评分ESM-1v、ESM-2
逆向折叠PDB/mmCIF、FASTAFASTA、CSV 评分文件ESM-IF1、sample_sequences.py
宏基因组搜索FASTA、结构文件匹配结构列表、嵌入数据ESM Atlas API
MSA 分析A3M 比对文件接触图、进化特征ESM-MSA-1b
生成式设计自然语言/编程指令氨基酸序列ESM-2
### 使用ESM分类模型的相关信息 #### 什么是ESM分类模型ESM(Evolutionary Scale Modeling)是一系列用于蛋白质序列建模的深度学习框架,主要由Meta AI开发。这些模型利用Transformer架构来捕捉蛋白质序列中的复杂模式和关系[^3]。其中,ESM-1b、ESM-2模型被广泛应用于蛋白质的功能预测、结构预测以及突变效应分析等领域。 #### 如何使用ESM分类模型? ##### 安装依赖库 为了使用ESM模型,首先需要安装必要的Python包。可以通过pip命令完成安装: ```bash pip install esm biotite torch ``` ##### 加载预训练模型 以下是加载并运行ESM-2模型的一个简单示例代码: ```python import esm import torch # 加载预训练模型 ESM-2 650M 参数版本 model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() batch_converter = alphabet.get_batch_converter() # 设置设备为GPU或CPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) # 输入待处理的蛋白质序列 data = [ ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"), ] # 将输入转换为批量张量形式 batch_labels, batch_strs, batch_tokens = batch_converter(data) batch_tokens = batch_tokens.to(device) # 运行前向传播获取表示 with torch.no_grad(): results = model(batch_tokens, repr_layers=[33], return_contacts=True) # 提取最后一层的隐藏状态作为蛋白质序列的表示 token_representations = results["representments"][33] print(token_representations.shape) # 输出形状应为 (batch_size, sequence_length, embedding_dim) ``` 上述代码展示了如何加载`esm2_t33_650M_UR50D`模型,并提取给定蛋白质序列的最后一层嵌入表示。此表示可以进一步用于下游任务,如分类或回归。 ##### 应用场景 1. **零样本预测** 基于大规模无监督学习ESM模型能够在未见过的数据上执行零样本预测。例如,《Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences》一文中提到,ESM-1v模型具备强大的零样本预测能力,可用于评估氨基酸突变的影响[^4]。 2. **蛋白质功能注释** 利用ESM模型生成的序列表示,可以直接映射到特定的功能类别。这种技术已被证明在多个基准测试中表现优异。 3. **蛋白质结构预测** ESM模型还可以与其他工具(如AlphaFold3)结合,提供高质量的初始特征表示,从而提高整体性能。 #### 实现细节与注意事项 - 数据准备阶段需注意标准化氨基酸序列格式。 - 对于较大的模型实例(如ESM-2 15B),建议使用高性能计算环境以减少内存占用和运算时间。 - 如果目标是微调现有模型,则可通过PyTorch接口轻松实现迁移学习策略。 --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值