生物大模型文献及代码精读(三)找到所有物种的通用基因?
今天给大家分享的文献来自于斯坦福大学计算机科学系、瑞士联邦理工学院的计算机与通信科学学院和清华大学计算机科学与技术系三家单位(居然没有生物相关单位,计算机人都来研究生物了?)合作的单细胞数据中的基因embedding大作Toward universal cell embeddings: integrating single-cell RNA-seq datasets across species with SATURN。
文章内容梳理
摘要简介
做了什么? 在这里我们介绍SATURN(翻译为土星),一种基于蛋白质语言模型的通用细胞嵌入编码基因的学习方法。通过整合来自不同物种的数据集及基因组相似性,作者提出了一种Macrogenes,用于综合不同基因跨物种共表达;作者基于这种embedding方法,将其用于单细胞多物种整合等下游任务,效果拔群;同时还展现了识别位置基因的功能的作用。
意义是什么?
-
解决跨物种分析难题
-
增强对细胞类型的理解:SATURN方法通过整合不同物种的单细胞RNA测序数据,揭示了细胞类型在进化过程中的保守性和多样性。
-
创新的生物信息学工具:文章介绍了一种创新方法,利用蛋白质语言模型生成的嵌入来表示基因,为跨物种的分子相似性提供了新的视角。
-
发现跨物种共享的基因程序:SATURN通过对大基因(macrogenes)进行差异表达分析,能够识别出功能相关的基因模块,这些模块跨越了物种界限,有助于识别和理解跨物种共享的生物学过程。
话外 :感觉有一种WGCNA套皮重生的意思,原来咱们的WGCNA不就是用基因模块对应的这里的Macrogenes的吗?更何况这里用的Macrogenes就是用Kmeans做出来的,WGCNA还用了指数来强化基因聚类。但是不同的是,WGCNA没有关注到基因序列的信息,只是矩阵表达的信息。 所以在蛋白质大语言模型的加持下,还是发到了Nature Methods
WGCNA分析
文章结果速览
模型总览
核心挑战: 跨物种整合的主要难点在于不同物种的基因数据集中含有不完全同源的基因,如果仅选取那些有一对一同源物的基因,会导致大量生物学上有意义的基因信息丢失。
SATURN的解决方案: 为了解决上述问题,SATURN采用大型蛋白质语言模型来学习细胞嵌入(cell embeddings),通过蛋白质嵌入将不同物种的scRNA-seq(单细胞RNA测序)数据集映射到一个基于功能相关性的低维共享空间中。过程包括: 输入scRNA-seq的count、大型蛋白质嵌入语言模型(如ESM2)产生的蛋白质嵌入,以及种内的细胞注释。
同时SATURN学习到了一个可解释的、多物种间共享的特征空间,即macrogene space。在这个空间中,基因被推断为功能相关的,即使它们在序列基础上并不明显同源。
应用效果: SATURN成功地在不同物种间转移了细胞类型的注释,发现了同源性和物种特有的细胞类型,且性能优于现有的跨物种整合方法。
SATURN架构
一、基于SATURN构建的多物种单细胞图谱
多物种数据整合: SATURN利用深度学习技术,将不同物种的单细胞RNA测序(scRNA-seq)数据集整合起来。它通过将基因表达与由大型蛋白质语言模型生成的蛋白质嵌入相结合,成功地创建了包含人类、鼠狐猴、小鼠等多个物种的哺乳动物细胞图谱,总计覆盖了335,000个细胞,横跨九个常见组织。 这种方法不仅限于哺乳动物,还应用于青蛙和斑马鱼的胚胎发育数据集,展示了其在进化关系较远物种间的应用潜力。
宏观基因(Macrogenes)概念:SATURN提出了一种“宏观基因”的概念,即将具有相似蛋白质嵌入的基因分组。通过学习基因与这些宏观基因之间的关联强度,SATURN能够捕捉到功能上相关的基因群,即便它们在不同物种中的基因序列可能不直接同源。这种方法有助于识别和分类那些在功能上相似但基因表达模式在物种间有所差异的细胞类型。
细胞类型标注的校正与转移:SATURN能够重新注释细胞类型并纠正不完整的注释。例如,在对包含人类、恒河猴、猕猴、小鼠和猪的眼部细胞图谱(AH图谱)进行整合时,SATURN揭示了一些细胞类型如色素细胞、巨噬细胞和纤毛肌在所有物种中的一致性排列,同时也识别出了仅在部分物种中存在的细胞类型,如成纤维细胞。通过对原始注释的重新分组,SATURN修正了如成纤维细胞和角膜内皮细胞等的分类,发现并纠正了原有注释中的错误,如将原标记为巨噬细胞的小鼠细胞重新归类为成纤维细胞,并通过表达特定标志基因进一步证实了这一重分类的准确性。
二、基于macrogenes做差异分析
-
通过将每个细胞内基因的表达值乘以相应的基因-宏观基因权重,然后加权求和并规范化得到的。这使得即使在不同物种间缺乏直接基因同源的情况下,也能比较细胞间的表达模式。
-
完成差异表达分析后,SATURN通过识别对每个宏基因贡献最大的基因来解读其生物学意义。 这些高权重基因往往代表了该宏基因功能的关键组成部分。
三、基于macrogenes捕捉基因间的序列同源性以及功能相似性
-
宏基因同源性的捕获:SATURN生成的宏基因能够重新捕捉基于序列的基因同源关系。研究团队通过用BLASTP计算在斑马鱼和青蛙之间宏基因中排名靠前的基因对中同源基因的比例发现,即使仅考虑每个物种排名最靠前的基因,SATURN的宏基因中有56%能够重新识别出同源信息。而当考虑每个物种的前十名排名基因时,这一比例上升到了91.2%。
-
超越序列同源的功能相似性:宏基因不仅能识别基于序列的同源基因,还能揭示那些通过常规序列比对工具未被认定为同源,但功能上存在相似性的基因。 通过基因本体(GO)分析,同一宏基因的基因集显示出显著富集的GO功能相似性。
四、SATURN在跨物种细胞数据集分析方面的优势
基于对青蛙和斑马鱼胚胎发育数据集的对齐来评估能否有效地将一个参考物种(斑马鱼)的细胞类型标签转移到查询物种(青蛙)上,这一任务准确率达到93%,且显著高于其他任务。 这一卓越表现的背后是SATURN对细胞嵌入的优化策略,包括使用预训练损失函数来优化模型,并采用微调阶段的弱监督度量学习目标来自动学习跨物种的距离度量。这在后续的代码模块我们会着重介绍。
日常货比三家环节
五、整合不同物种的细胞图谱
SATURN成功整合了五种物种的数据,展示了其在跨物种数据整合方面的强大能力。通过对涵盖人类、猪、小鼠、恒河猴和狨猴的细胞图谱数据进行分析,SATURN不仅能够识别出跨物种间保守的细胞类型,还能重新注释细胞类型,甚至在进化关系较远的物种间有效转移注释。
六、预测同源基因的不同功能
SATURN还展示了一项关键能力,即预测同源基因间的不同功能。以青光眼为例,研究发现人类的MYOC基因并未与其在猪、小鼠、恒河猴和狨猴中的同源基因归为同一宏基因(macrogene)。宏基因是由蛋白质嵌入模型识别出的功能相似的基因集合。 人类的MYOC基因反而与非同源基因A2M(已知也与青光眼有关)以及其他几个非人类物种的基因紧密关联。这表明,尽管这些基因在序列上可能相似,但它们在不同物种中的功能可能存在显著差异。
总结
SATURN的本质是将基因的序列语言信息,以macrogene的维度,加入到了单细胞聚类和信息整合中,我们原来将一个单细胞矩阵中的不同基因在矩阵计算中作为权重相等的等价地位,但是实际上它们并不是等价的,所以我们的分析实际上是有局限的,本文通过不同基因对于macrogene的不同权重来解决了这一问题,虽然借鉴了WGCNA的基因聚块思想,但是实际上是从另一个角度来关注了这个问题。
从另一方面说,或许每种动物被设计/进化出来的过程中,背后都遵循着一套程序,但是这套程序在物质世界的表现形式却千差万别。就像电脑不同的操作系统,最后实现的功能是大童同小异。“善行无辙迹,善言无瑕谪,善数不用筹策,善闭无关楗而不可开,善结无绳约而不可解。”。正如这篇文章中的macrogene作为大多数基因的代表,是这种背后程序的一种体现。真正的真理往往隐藏在无形的规律组合中。
文章模型架构解析及代码梳理
模型架构以及对应模型代码解读
本文的核心是将多物种表达数据映射到一个联合的低纬macrogene的表达空间中,这个Macrogene的原型,也就是训练的模型架构是来自soft-clustering protein embeddings 。具体步骤如下:
-
蛋白质嵌入生成:首先,为所有基因生成蛋白质嵌入向量。这通常通过应用预先训练好的蛋白质语言模型完成,例如ESM2,该模型接受氨基酸序列并输出一个高维向量表示蛋白质。
-
K-means聚类: 然后,利用k-means算法对这些蛋白质嵌入进行聚类,以确定一组质心M,这里的质心代表不同的macrogenes。
-
基因到macrogene的初始权重:对于每个基因g到macrogene m的初始权重Wg,m,使用如下公式计算: 其中,rdm,g是基因g到macrogene m的欧几里得距离排名,且当g是最接近m的基因时rdm,g=1。这个初始化确保了基因与它们最近的macrogene之间的权重较高。
-
随后作者进行预训练,包括编码和解码两个模块,以及损失函数的定义。编码模块根据基因表达和macrogene权重生成细胞的低维表示,解码模块则试图从这个表示重构原始数据。
-
编码: 细胞c的macrogene表达值ec由基因表达矩阵Xsc与macrogene权重WTs相乘后经过ReLU和LayerNorm激活得到,公式为:
-
解码: 解码模块输出三个参数,分别对应负二项分布(ZINB)的均值μc、偏移量Oc和形状参数θ,通过全连接层计算得到。
-
损失函数: 包括两部分,一部分是基于ZINB分布的重构损失ℒrc,另一部分是确保基因到macrogene权重反映蛋白质嵌入空间相似性的额外损失ℒs。其中重构损失ℒrc是负对数似然函数,基于ZINB分布参数化,因为ZINB分布特别适用于存在过多零计数的情况,需要将损失改成这样: 以及额加的损失项ℒs确保基因到macrogene权重与蛋白质嵌入相似度一致,通过比较打乱前后的相似度计算MSE(均方误差):
saturn_model.py
让我来看看作者的模型代码是怎么写的,如何与以上公式对应
def __init__(self, gene_scores, dropout=0, hidden_dim=128, embed_dim=10, species_to_gene_idx={}, vae=False, random_weights=False,
sorted_batch_labels_names=None, l1_penalty=0.1, pe_sim_penalty=1.0):
super().__init__()
self.dropout = dropout
self.embed_dim = embed_dim
self.hidden_dim = hidden_dim
self.sorted_batch_labels_names = sorted_batch_labels_names
if self.sorted_batch_labels_names is not None:
self.num_batch_labels = len(sorted_batch_labels_names)
else:
self.num_batch_labels = 0
self.num_gene_scores = len(gene_scores)
self.num_species = len(species_to_gene_idx)
self.species_to_gene_idx = species_to_gene_idx
self.sorted_species_names = sorted(self.species_to_gene_idx.keys())
self.num_genes = 0
self.vae = vae
for k,v in self.species_to_gene_idx.items():
self.num_genes = max(self.num_genes, v[1])
self.p_weights = nn.Parameter(gene_scores.float().t().log())
if random_weights: # for the genes to centroids weights
nn.init.xavier_uniform_(self.p_weights, gain=nn.init.calculate_gain('relu'))
self.num_cl = gene_scores.shape[1]
self.cl_layer_norm = nn.LayerNorm(self.num_cl)
self.expr_filler = nn.Parameter(torch.zeros(self.num_genes), requires_grad=False) # pad exprs with zeros
if self.vae:
# Z Encoder
self.encoder = nn.Sequential(
full_block(self.num_cl, hidden_dim, self.dropout),
)
self.fc_var = nn.Linear(self.hidden_dim, self.embed_dim)
self.fc_mu = nn.Linear(self.hidden_dim, self.embed_dim)
else:
self.encoder = nn.Sequential(
full_block(self.num_cl, self.hidden_dim, self.dropout),
full_block(self.hidden_dim, self.embed_dim, self.dropout),
)
# Decoder
self.px_decoder = nn.Sequential(
full_block(self.embed_dim + self.num_species + self.num_batch_labels, self.hidden_dim, self.dropout),
)
self.cl_scale_decoder = full_block(self.hidden_dim, self.num_cl)
self.px_dropout_decoders = nn.ModuleDict({
species: nn.Sequential(
nn.Linear(self.hidden_dim, gene_idxs[1] - gene_idxs[0])
) for species, gene_idxs in species_to_gene_idx.items()}
)
self.px_rs = nn.ParameterDict({
species: torch.nn.Parameter(torch.randn(gene_idxs[1] - gene_idxs[0]))
for species, gene_idxs in species_to_gene_idx.items()}
)
self.metric_learning_mode = False
# Gene to Macrogene modifiers
self.l1_penalty = l1_penalty
self.pe_sim_penalty = pe_sim_penalty
self.p_weights_embeddings = nn.Sequential(
full_block(self.num_cl, 256, self.dropout) # This embedding layer will be used in metric learning to encode
# similarity in the protein embedding space
)
参数: gene_scores: 基因重要性分数,用于初始化权重矩阵。
dropout, hidden_dim, embed_dim: 模型超参,分别控制Dropout比率、隐藏层大小和嵌入维度。
species_to_gene_idx: 物种到其基因索引范围的映射,用于处理多物种数据。
vae: 是否使用变分自编码器结构。
random_weights, sorted_batch_labels_names, l1_penalty, pe_sim_penalty: 控制权重初始化、批次标签处理及正则化项的参数。
模型基本模块有:
-
Layer Normalization (LayerNorm): self.cl_layer_norm = nn.LayerNorm(self.num_cl) 用于对输入特征进行归一化,提高模型训练的稳定性和速度。
-
Linear Layers:
-
Sequential Container:这俩个不解释了
-
Parameter: nn.Parameter 用于定义模型中需要学习的参数,例如权重矩阵 self.p_weights 和随机初始化的参数 self.expr_filler。
-
ModuleDict: nn.ModuleDict 在 self.px_dropout_decoders 中使用,创建了一个字典,其中键为物种名称,值为特定于物种的解码器网络,使得模型可以根据输入物种动态选择合适的解码路径。
-
ParameterDict: nn.ParameterDict 在 self.px_rs 中使用,为每个物种定义了一个随机初始化的参数,用于表达式重构中的随机效应或噪声参数。
-
Custom Block (full_block): 网络的基本构建自定义块,1个线性、1个layerN、1个relu、1个dropout
-
Embedding Layer: self.p_weights_embeddings 是一个嵌入层,用于将簇(Cluster)相关的权重转换到另一个特征空间,特别用于度量学习任务中编码蛋白质嵌入空间的相似性。
我们可以发现,genescores输入后通过这里的nn.ParameterDict模块进行随机初始化,衍生出self.p_weights和self.num_cl
self.p_weights = nn.Parameter(gene_scores.float().t().log())
if random_weights: # for the genes to centroids weights
nn.init.xavier_uniform_(self.p_weights, gain=nn.init.calculate_gain('relu'))
self.num_cl = gene_scores.shape[1]
而self.num_cl则被用于后续的encoder中基础模块full_block的构筑
if self.vae:
# Z Encoder
self.encoder = nn.Sequential(
full_block(self.num_cl, hidden_dim, self.dropout),
)
self.fc_var = nn.Linear(self.hidden_dim, self.embed_dim)
self.fc_mu = nn.Linear(self.hidden_dim, self.embed_dim)
else:
self.encoder = nn.Sequential(
full_block(self.num_cl, self.hidden_dim, self.dropout),
full_block(self.hidden_dim, self.embed_dim, self.dropout),
)
同时,我们发现p_weights_embeddings是用与编码蛋白语言大模型计算出的相似度的数据的
self.p_weights_embeddings = nn.Sequential(
full_block(self.num_cl, 256, self.dropout) # This embedding layer will be used in metric learning to encode
# similarity in the protein embedding space
)
而在forward方法中,我们可以看到通过expr矩阵和即in权重,进行encoder中的计算,这里的前20行就是对应encoder的计算公式的
batch_size = inp.shape[0]
# Pad the appened expr with 0s to fill all gene nodes
expr = torch.zeros(batch_size, self.num_genes).to(inp.device)
filler_idx = self.species_to_gene_idx[species]
expr[:, filler_idx[0]:filler_idx[1]] = inp
expr = torch.log(expr + 1)
# concatenate the gene embeds with the expression as the last item in the embed
expr = expr.unsqueeze(1)
# GNN and cluster weights
clusters = []
expr_and_genef = expr
x = nn.functional.linear(expr_and_genef.squeeze(), self.p_weights.exp())
x = self.cl_layer_norm(x)
x = F.relu(x) # all pos
x = F.dropout(x, self.dropout)
encoder_input = x.squeeze()
encoded = self.encoder(encoder_input)
if self.vae:
# VAE
mu = self.fc_mu(encoded)
log_var = self.fc_var(encoded)
encoded = self.reparameterize(mu, log_var)
else:
mu = None
log_var = None
spec_1h = torch.zeros(batch_size, self.num_species).to(inp.device)
#spec_idx = np.argmax(np.array(self.sorted_species_names) == species) # Fix for one hot
spec_idx = 0
spec_1h[:, spec_idx] = 1.
if self.num_batch_labels > 0:
# construct the one hot encoding of the batch labels
# also a categorical covariate
batch_1h = torch.zeros(batch_size, self.num_batch_labels).to(inp.device)
batch_idx = np.argmax(np.array(self.sorted_batch_labels_names) == batch_labels)
batch_1h[:, batch_idx] = 1.
spec_1h = torch.hstack((spec_1h, batch_1h)) # should already be one hotted
if encoded.dim() != 2:
encoded = encoded.unsqueeze(0)
if self.metric_learning_mode:
# Return Encoding if in metric learning mode (encoder only)
return encoded
其中:
-
Input Preprocessing: 首先,对输入基因表达矩阵inp进行预处理,包括填充零值以匹配所有基因节点、对数值进行对数变换,然后将其转换为适合模型处理的形状。
-
Linear Transformation & Activation: 使用线性层(权重为self.p_weights.exp())变换处理后的表达数据,并应用层归一化(LayerNorm)、ReLU激活以及Dropout,这一系列操作构成了一种特征提取过程,准备数据进入编码器。
-
Encoding: 接着,通过调用self.encoder(encoder_input)对处理后的特征进行编码,得到一个潜在表示(latent representation)。这个步骤是降维处理,目的是捕捉输入数据的主要结构信息。
decoded = self.px_decoder(torch.hstack((encoded, spec_1h)))
library = torch.log(inp.sum(1)).unsqueeze(1)
# modfiy
cl_scale = self.cl_scale_decoder(decoded) # num_cl output
# index genes for mu
idx = self.species_to_gene_idx[species]
cl_to_px = nn.functional.linear(cl_scale.unsqueeze(0), self.p_weights.exp().t())[:, :, idx[0]:idx[1]]
# distribute the means by cluster
px_scale_decode = nn.Softmax(-1)(cl_to_px.squeeze())
px_drop = self.px_dropout_decoders[species](decoded)
px_rate = torch.exp(library) * px_scale_decode
px_r = torch.exp(self.px_rs[species])
而其中的decoder将从encoder中计算出decoder
-
Conditioning Information: 添加物种和批次标签的独热编码(one-hot encoding)到编码器的输出中,作为解码时的条件信息。 这允许模型基于特定的物种或批次信息进行个性化生成。
-
Decoding: 调用self.px_decoder解码器,输入是编码器输出与条件信息的拼接,用于生成重构的基因表达数据。这是从低维潜在空间回到原始或近似原始数据空间的过程。
-
Cluster Scaling and Dropout: 之后,代码涉及一些特定于cluster的操作,如计算每个簇的缩放因子(cl_scale),以及利用这些缩放因子和dropout解码器对解码结果进行调整,以最终确定每个基因的表达率预测(px_rate)
接下来是loss,对应于这部分公式
首先是ZINO分布的loss:
def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout):
'''https://github.com/scverse/scvi-tools/blob/master/scvi/module/_vae.py'''
return -ZeroInflatedNegativeBinomial(
mu=px_rate, theta=px_r, zi_logits=px_dropout
).log_prob(x).sum(dim=-1)
def gene_weight_ranking_loss(self, weights, embeddings):
# weights is M x G
x1 = self.p_weights_embeddings(weights.t())
# genes x 256
loss = nn.MSELoss(reduction="sum")
similarity = torch.nn.CosineSimilarity()
idx1 = torch.randint(low=0, high=x1.shape[0], size=(x1.shape[0],))
x2 = x1[idx1, :]
target = similarity(embeddings, embeddings[idx1, :])
return loss(similarity(x1, x2), target)
项目运行流程
先安装一手
pip install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install -r requirements.txt
有一点小bug这里应该先装torch==1.10.2+cu113,再装requirments
作者提供的示例: 首先是用作者给的脚本用scanpy预处理不同物种的单细胞数据
!wget -O ./data/WagnerScience2018.h5ad https://kleintools.hms.harvard.edu/paper_websites/wagner_zebrafish_timecourse2018/WagnerScience2018.h5ad
用scanpy预处理
zebrafish = sc.read(os.path.join(loc,'WagnerScience2018.h5ad'))
zebrafish.obs['cluster'] = pd.Categorical([z[6:] if '-' in z else z for z in zebrafish.obs['ClusterName']])
with open(os.path.join(loc,'zebrafish_cell_types_mapping')) as f:
cell_types_mapping = f.readlines()
ct_map = {}
for line in cell_types_mapping[1:]:
el = line.split("\t")
ct_map[el[0].strip()] = el[1].strip()
ct_map['periderm'] = 'Periderm'
ct_map['pluripotent'] = 'Pluripotent'
ct_map['neural - floorplate posterior'] = 'Notoplate'
ct_map['neural crest - mcamb'] = 'Neural crest'
ct_map['neural crest - melanoblast'] = 'Neural crest'
ct_map['neural crest - iridoblast'] = 'Neural crest'
ct_map['neural crest - xanthophore'] = 'Neural crest'
ct_map['neural crest - crestin'] = 'Neural crest'
#load zebrafish data
zebrafish = sc.read(os.path.join(loc,'zebrafish_annot.h5ad'))
zebrafish.obs#再看一眼
zebrafish.obs_names = [x for x in zebrafish.obs_names ]
zebrafish.var_names = [x for x in zebrafish.var_names ]
zebrafish.obs.cell_type = [x for x in zebrafish.obs.cell_type ]
sc.pp.filter_cells(zebrafish, min_genes=500)
sc.pp.filter_genes(zebrafish, min_cells=10)
zebrafish.X.toarray().max()
zebrafish.write(os.path.join(loc, "zebrafish.h5ad"))
zebrafish
#最后拿到的矩阵如下AnnData object with n_obs × n_vars = 63371 × 30032
obs: 'n_counts', 'unique_cell_id', 'cell_names', 'library_id', 'batch', 'ClusterID', 'ClusterName', 'TissueID', 'TissueName', 'TimeID', 'cluster', 'cell_type', 'n_genes'
var: 'n_cells'
# Make the csv
import pandas as pd
df = pd.DataFrame(columns=["path", "species", "embedding_path"])
df["species"] = ["frog", "zebrafish"]
df["path"] = ["data/frog.h5ad", "data/zebrafish.h5ad"]
##### CHANGE THESE PATHS #####
frog_embedding_path = "/media/ubuntu/20TB/Project/Deeplearning/GPT/pythonProject/SATURN-main/protein_embeddings/proteome/Xenopus_
tropicalis.UCB_Xtro_10.0.gene_symbol_to_embedding_ESM1b.pt"
zebrafish_embedding_path = "/media/ubuntu/20TB/Project/Deeplearning/GPT/pythonProject/SATURN-main/protein_embeddings/proteome/Danio_rerio.GRCz11.gene_symbol_to_embedding_ESM1b.pt"
##############################
df["embedding_path"] = [frog_embedding_path, zebrafish_embedding_path]
df.to_csv("data/frog_zebrafish_run.csv", index=False)
df
pd.read_csv("data/frog_zebrafish_cell_type_map.csv").head(10)
这里的/dfs/project/cross-species/yanay/data/proteome/embeddings/Xenopus_tropicalis.Xenopus_tropicalis_v9.1.gene_symbol_to_embedding_ESM2.pt需要先预训练一下,这里用了https://github.com/facebookresearch/esm里面的ESM-1b Transformer,用于从基因序列中生成embeddings,这里用了5G显存跑了5分钟左右。
这里用Xenopus_tropicalis演示一下作者是怎么作embedding的
cd .../proteome
wget https://ftp.ensembl.org/pub/release-108/fasta/xenopus_tropicalis/pep/Xenopus_tropicalis.UCB_Xtro_10.0.pep.all.fa.gz
gunzip Xenopus_tropicalis.UCB_Xtro_10.0.pep.all.fa.gz
python ../clean_fasta.py \
--data_path Xenopus_tropicalis.UCB_Xtro_10.0.pep.all.fa \
--save_path Xenopus_tropicalis.UCB_Xtro_10.0.pep.all_clean.fa
git clone git@github.com:facebookresearch/esm.git
#生成embedding
python ../esm-main/scripts/extract.py \
esm1b_t33_650M_UR50S \
Xenopus_tropicalis.UCB_Xtro_10.0.pep.all_clean.fa \
Xenopus_tropicalis.UCB_Xtro_10.0.pep.all_clean.fa_esm1b \
--include mean \
--truncate
python extract.py \
esm1b_t33_650M_UR50S \
.../proteome/Homo_sapiens.GRCh38.pep.all_clean.fa \
../proteome/embeddings/Homo_sapiens.GRCh38.pep.all_clean.fa_esm1b \
--include mean \
--truncate
python ../map_gene_symbol_to_protein_ids.py --fasta_path Xenopus_tropicalis.UCB_Xtro_10.0.pep.all.fa --save_path Xenopus_tropicalis.UCB_Xt
ro_10.0.gene_symbol_to_protein_ID.json
python ../convert_protein_embeddings_to_gene_embeddings.py --embedding_dir Xenopus_tropicalis.UCB_Xtro_10.0.pep.all_clean.fa_esm1b/ --ge
ne_symbol_to_protein_ids_path Xenopus_tropicalis.UCB_Xtro_10.0.gene_symbol_to_protein_ID.json --embedding_model ESM1b --save_path Xen
opus_tropicalis.UCB_Xtro_10.0.gene_symbol_to_embedding_ESM1b.pt
同样的计算一下另一个斑马鱼的
数据准备好了之后,随后开始训练,这里device_num写你自己的显卡数量,比如2块就写1,从0开始计算
python3 ../../train-saturn.py --in_data=data/frog_zebrafish_run.csv \
--in_label_col=cell_type --ref_label_col=cell_type \
--num_macrogenes=2000 --hv_genes=8000 \
--centroids_init_path=saturn_results/fz_centroids.pkl \
--score_adata --ct_map_path=data/frog_zebrafish_cell_type_map.csv \
--work_dir=. \
--device_num=1