关于ESMC-6B蛋白质语言大模型API的使用(ESM3/ESMC简介、batch序列数据输入尝试)

前言

ESM2模型仓库:facebookresearch/esm: Evolutionary Scale Modeling (esm): Pretrained language models for proteins​​​​​​

ESM3/ESMC模型仓库:

evolutionaryscale/esm

ESM3模型其实在今年年初就已经开源了,但是官方称ESM3主要是“专注于在医疗与其他应用方面的蛋白质的可控生成”,并且主要训练方法是mask prediction,所以用它来做序列嵌入生成的文章是比较少的(应该是)。

但今年的12月4号,ESMC模型开源了,并且开放了线上调用的接口,官方称ESMC是“专注于获取蛋白质隐藏生物信息的表征”的,是ESM2的上位替代,还列出了ESMC-600M可以上位替代ESM2-3B,表现和效率都优于ESM2,这下就不得不体验一把了。

替代表(官方认为的):

ESMCESM2

ESMC表现

300M650M相近
600M3B/15B

上位替代

接近ESM2-15B

6B-远超ESM2

官方开源了ESM3、ESMC的代码,还开源了ESMC-300M和600M的参数,可以本地部署。

但既然有目前最优的蛋白质序列表征模型ESMC-6B,那肯定还是得试一试效果的。

ESMC-6B模型API的使用

首先pip install esm

然后官方给出了调用的代码:

from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.api import ESMProtein, LogitsConfig

# Apply for forge access and get an access token
forge_client = ESM3ForgeInferenceClient(model="esmc-6b-2024-12", url="https://forge.evolutionaryscale.ai", token="<your forge token>")
protein_tensor = forge_client.encode(protein)
logits_output = forge_client.logits(
   protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output.logits, logits_output.embeddings)

但是这个代码需要获取到ESMC的使用token权限,需要申请,而且有点麻烦

https://forge.evolutionaryscale.ai.

有幸拿到访问许可之后就可以体验了。

ESMC-6B通过调用logits输出的embedding是一个 [1,序列长度,2560] 的tensor

一条序列嵌入所需的时间约为4~5s

如果你只需要跑某个物种的蛋白质数据集可以直接一条一条的查询,大概只用1~3天就能跑完

但如果要跑swissprot数据集(50w+),那可能需要跑上一个月,更别说还要对序列做mask数据增强。

并且输出的所占的空间很大,一般来说需要做一个平均池化或者什么池化

所以我就开始尝试能够一次性跑完一组序列的方法。

漫长的issue、源码、文档阅读

Issue1

首先是看到了这个issue

About Generating Protein Sequence Embeddings with Your Model · Issue #2 · evolutionaryscale/esm

这是开发者提出的一种方法

其实并没有解决batch预测的问题,只是利用了一下python语法的特性

本质上就是把多个序列拼接成一个,然后去跑模型

并且在调用encode API得到的序列依然只有一个<cls>和一个<eos>

没有根据序列来做tokenize

更加困惑的是,我们并不知道,如果将序列编码成

<cls>sequence1<eos><cls>sequence2<eos>...

这样的形势之后,ESMC的输出是否会受到不同序列间的数据特征的影响

毕竟它也是基于transformer的模型,n^2计算的自注意力层或多或少会导致不同序列间的影响。

(不过也可以试试看,后文我将API的encode()编码逆向出来了)

Issue2

然后就看到了这个Issue:make a batch to process multi protein at one time. · Issue #19 · evolutionaryscale/esm

虽然这个是针对ESM3生成式模型提出的一种batch方法

当时我也抱着死马当做活马医的态度试了一下

但是是过不了encode()的,因为encode()里面传入的内容只能是ESMProtein,不能使用其他的数据结构。

而batch_generate()输出是Sequence[ProteinType]

源码与文档分析

成功获取token后,可以看到reference文档,但是里面只写了http传输的数据包格式,并且再一次让我确定了,encode()和logits()都只能接受一个蛋白质序列。

但是根据源码,我们可以看出它是使用json传输request数据的,request里面就包含了我们要传过去的序列数据

源码中可以看出encode的输入格式已经被限制死了,ESMProtein中的seq数据只能是str

但是logits中传入的ESMProteinTensor还可以改变其中sequence的tensor形状

于是我决定尝试修改传入的protein_tensor.sequence的形状:

from esm.sdk.forge import ESM3ForgeInferenceClient
from esm.sdk.api import ESMProtein, LogitsConfig
import torch

protein = ESMProtein(sequence='ACDEFGHIKLMNPQRSTVWY_')

forge_client = ESM3ForgeInferenceClient(model="esmc-6b-2024-12", url="https://forge.evolutionaryscale.ai", token="")
protein_tensor = forge_client.encode(protein)
print(protein_tensor.sequence.shape)
protein_tensor.sequence = torch.stack([protein_tensor.sequence])
print(protein_tensor.sequence.shape)
logits_output = forge_client.logits(
   protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)
print(logits_output)
print(logits_output.logits, logits_output.embeddings.shape)

但很不幸的是报错了,这次并不是encode()或者logits()的报错,而是网页端返回的报错,我们绕过了本地的所有数据类型限制,但是网页并不允许这样的tensor输入,以下是报错信息:

具体报错信息解释就是input.sequence[0]必须是一个合法的整数,而不是tensor

而我们改变形状之后,input.sequence[0]是第一个序列对应的tensor

于是

再看了一下ESMC模型的代码,发现模型主体部分是完全支持[....,.....,序列长度]的tensor输入的

logits函数调用的就是ESMC的forward

模型里面也没有多少玄学的操作,就一个Embedding、Transformer、regression

embedding不用说了,TransformerStack和RegressionHead(线性层+GELU+LN+线性)都是可以接入多维输入的

按道理来说使用ESMC模型端处理多维数据是没有任何问题的

这里就排除了ESMC模型端的报错

既然不是本地调用的问题,也不是远端模型的问题,最有可能的是服务器http解包的时候判断了读入数据的格式,避免了用户传入batch数据。

以上也只是我的猜想,但毕竟再怎么说server端代码肯定是不会开源的。

优化思路

尝试减小http包优化网络延迟

虽然说ESMC支持返回值采样forward_and_sample(),可以设置samplingconfig,直接获取蛋白质的平均值池化,显著地减小了http传输的数据量,但是这个功能在ESMC-6B模型中并不存在!以下是尝试使用的报错:

真是堵住了所有的路啊。。。。

尝试优化encode()减少http传输次数

最终成功的就只有这个了,但也优化不了多少,5秒一个序列查询优化到4.7秒左右

原理是通过检查encode()调用后返回的http报文,可以发现是有规律的

其实ESMC-6B远端模型使用的encode()和EsmSequenceTokenizer.tokenize()只有“_”的编码不同(encode是32,tokenize是3),其他的起始符<cls>:0, 占位符<pad>:1, 结尾符<eos>:2

all_amino_acid_number = {'A':5, 'C':23,'D':13,'E':9, 'F':18,
                         'G':6, 'H':21,'I':12,'K':15,'L':4,
                         'M':20,'N':17,'P':14,'Q':16,'R':10,
                         'S':8, 'T':11,'V':7, 'W':22,'Y':19,
                         '_':32}
def esm_encoder_seq(seq, pad_len):
    s = [all_amino_acid_number[x] for x in seq]
    while len(s)<pad_len:
        s.append(1)
    s.insert(0,0)
    s.append(2)
    return torch.tensor(s)

这样就在本地进行encode(),减少一次http的RTT时间

但由于传输的数据量本来就很少,这个也优化不了多少时间,logits占的数据量才是最大的,结果它又不让优化。

结尾吐槽

拼尽全力,依然无法战胜

没能让ESM团队用出全力,真是遗憾呢

同学,你也不想reviewer知道你2025年还在用ESM-1b吧(其实还挺好用的)

但其实大家也不必着急(跑一个月也不是不能接受(?,毕竟一劳永逸,特征传三代,人走数据在),ESM团队正在开发ESMC-6B的batch输入功能了

在esm.utils.sampling中,有个类叫_BatchedESMProteinTensor,这个应该是以后会使用的数据传输格式,但是现在直接使用依然会报错,需要等团队修改server端代码的检查才能使用。

另外,本地也可以部署ESMC-600M的模型,而且是ESM2-3B的上位替代,ESMC-6B只是一个锦上添花的模型啦。

评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值