生物大模型文献及代码精读(一)scGPT——3000万细胞的预训练模型?

生物大模型文献及代码精读(一)

现如今,大模型的风刮到了生物领域,单细胞领域是第一个吃到大模型红利的,所以准备开一个专栏,一起和各位生信人一起涨大模型知识!今天分享的文章是最热乎的生成式大模型scGPT,话不多说上文献。

今天的文献来自于加拿大多伦多大学Peter Munk心脏中心Bo Wang研究团队的成员在Nature Methods杂志上的大作。

文章内容梳理

摘要简介

做了什么? 建立了一个基于超过3300万个单细胞数据细胞的生成式预训练transformer,主要同时学习细胞和基因的表达。

意义是什么? 促进预训练模型在各种不同任务中的应用,如细胞类型注释、基因扰动预测、批次校正和多组学集成等方面,展最终实现“通用预训练,按需微调”。

话外:预训练模型到底是什么? 最简单的理解:自己练了一套花拳绣腿,但是自己修为不够,发挥不了这招式的威力,所以找高手传输内力,使得自己的花拳绣腿也威力强大起来了。 这里的预训练就是高手的“内力”,是花了很大功夫(数据量)炼成的,能够通用的“放大“武功威力(模型效果)的东西。

用科研的话来说就算,由于我们很多项目没有大数据支持(小数据),比如猫狗分类任务:只有100 张猫和狗的图片(无法解决的一个问题,精度很低)但是我可以通过100000 张鹅和鸭的图片(已知,有人做过的,通过这10w 张图片做了一个模型 A)的深层模型,加上我的小样本数据训练的浅层模型组合成效果不错的模型。(因为深层的数据特征往往是相似的,是一些抽象的元素(横竖撇捺等),如下图)使得我的模型能够在很少的样本情况下,也能训练成效果很好的模型。

图片摘自这位up主https://www.bilibili.com/video/BV15B4y1S7u4/?spm_id_from=333.999.0.0&vd_source=769ff3753997160a1ea8b796c9cbd242

文章结果速览

一、单细胞预训练资源

首先作者搭建了一个基于3千万单细胞的图谱的预训练模型,细胞资源来自CELLxGENE collection (https://cellxgene.cziscience.com/)将每个基因视为一个词汇(token),将每个细胞视为一个句子(sentence),并使用特殊的条件词汇来表示不同的测序批次或模式。 通过特殊设计的注意力掩码(attention mask)方式,模型可以通过自回归(auto-regressive)的方式预测未知基因或未知细胞的表达值。此外scGPT还使用了一个独特的函数,即基于细胞表示向量来预测所有基因表达值,从而增强了细胞表示向量的学习。scGPT的变换器层(transformer layer)则使用了预训练好的GPT-2模型的参数,以加速收敛和提高性能。

二、scGPT用于细胞类型鉴定

开发团队在不同的数据集上进行了大量的实验,以评估scGPT在细胞类型注释方面的性能:human pancreas dataset,tumor-infiltrating myeloid dataset和immune cell的细胞类型鉴定的过程中达到了不错的效果,性能超过了现存的TOSICA和scBERT模型。

三、模拟基因扰动

测序技术和基因编辑技术的最新进展极大地促进了大规模扰动实验的实施,使得科学家能够描述细胞对各种基因扰动的响应特征。scGPT能够利用从已知实验获得的细胞响应知识,并据此推断出未知扰动的响应。通过在基因维度上运用自注意力机制,scGPT能够编码扰动基因与其他基因响应之间的复杂交互关系。

  1. 作者使用scGPT针对Perturb-seq数据库Per-turbseq中涵盖的105个基因的236种扰动,使用经过微调的scGPT来虚拟扩展扰动达到有5,565种扰动模式,并利用UMAP可视化了每个扰动的预测平均响应图3d。

  2. 作者同时使用scGPT还具备预测体内逆向扰动预测(通过分析细胞群体中不同扰动条件下产生的表型或基因表达变化,反向推测出引起特定细胞状态变化的遗传或分子干预手段),作者使用Norman数据集中39种已知扰动(占总数的18%)对scGPT进行微调。 随后对所有扰动结果进行了比对,scGPT成功预测了产生观测结果的扰动源。

四、单细胞数据整合
  1. 去除批次效应,scGPT与三种流行的整合方法:scVI、Seurat和Harmony。在在peripheral blood mononuclear cell (PBMC) 10k (two batches)等数据集上进行比较,scGPT成功地分离出了所有细胞类型(图4a)。scGPT卓越的整合性能进一步体现在其高AvgBIO得分上(0.821),比对其他方法高出约5%-10%。

  2. 单细胞多组学整合,单细胞多组学(scMultiomic)数据集结合了表观遗传、转录组和翻译活动等多种遗传调控视角,呈现出在聚合细胞表征的同时保持生物学信号的独特挑战。比较了scGPT与两种最先进的方法scGLUE13和Seurat的效果,cGPT是唯一成功为CD8+初始T细胞生成独特聚类的方法(图4b)。

五、单细胞数据解释特定细胞状态下的基因网络

GRN中的转录因子、辅因子、增强子及其目标基因之间的相互作用介导了重要的生物学过程。现有的GRN推断方法常常依赖于静态基因表达的相关性或者伪时间估计作为因果图的proxy ,scGPT通过优化基因表达的生成模型,不仅在其基因嵌入中隐式编码了这样的关系,还在注意力映射中体现了这些关系。

  1. 训练scGPT模型的基因嵌入构建的人白细胞抗原(HLA)蛋白质相似性网络。在这个零样本环境下,scGPT模型成功突出了两个对应于特征明显的HLA类别的聚类:HLA-I类和HLA-II类基因,它们分别编码参与免疫反应的不同角色的抗原提呈蛋白。

  2. 作者基于‘immune human’ dataset 进行微调,并在此数据集中探索了针对免疫细胞类型的CD基因网络。scGPT成功识别出编码T细胞激活所需的T3复合物的基因群(CD3E、CD3D和CD3G),以及B细胞信号传导相关的CD79A和CD79B,以及作为HLA-I类分子共受体的CD8A和CD8B(图5b)。

  3. 对Reactome数据库(https://reactome.org/)进行了通路富集分析,并使用严格的多重检验校正(https://mathworld.wolfram.com/BonferroniCorrection.html 和方法)识别出高置信度的“通路显著性”。scGPT独特地额外识别出22条通路。证明了scGPT在捕捉复杂基因-基因连接并在更广泛的生物学背景下揭示特定机制方面具有优越的能力。

  4. scGPT的注意力机制还能够捕获单细胞水平基因-基因相互作用。在Adamson CRISPR干扰数据集中,scGPT识别出由DDIT3(编码一个转录因子)抑制所最直接影响的前20个基因。

六、预训练单细胞数据数量对scGPT的影响

作者训练了从3万个到3300万个正常人细胞的序列数据不等,模型架构相同的预训练模型,随着预训练数据量的增加,微调模型的性能也随之提高。

此外,作者还探究“上下文特定预训练”的影响(预先在一个特定细胞类型上训练scGPT模型,然后在相似细胞类型上针对下游任务进行微调)。作者在来自各个主要器官的正常人细胞上预训练了七个器官特异性的模型,随后在COVID-19数据集上对各个模型进行微调,以检验预训练上下文的影响。结果显示预训练中模型上下文的相关性与其后续数据整合任务表现之间的清晰关联。

总结

如今,将通过训练AI这一种“黑箱生命体”,再通过微调让它自己说出自己的生命体征的研究,已经逐渐进入我们的视野。研究电子生物,做赛博生物科研,已经不是一句空话。对此,五星评论家掌管抽象的申认为,人们终会在虚拟世界中搭建一个现实世界的复制体,但这到底是人类的福音?还是无尽欲望的开始?还有待时间去检验。不过,别忘了几千年前就有人提醒过我们:

凡所有相,皆是虚妄。

文章模型架构解析及代码梳理

模型架构解析

该内容来自项目https://github.com/bowang-lab/scGPT,可以发现,该项目实际上已经上线很久,而且已经发表在与预印本上,我们直接进入模型的model.py查看模型架构。 首先是作者定义的超复杂TransformerModel

class TransformerModel(nn.Module):
    def __init__(
        self,
        ntoken: int, #表示词汇表中单词的数量,用于确定词嵌入层的大小。
        d_model: int,#表示模型的隐藏层维度(embedding尺寸)
        nhead: int,#多头注意力的头数
        d_hid: int,#MLP(多层感知器)神经元数量
        nlayers: int, #型的编码器/解码器层数
        nlayers_cls: int = 3, #可能是额外隐藏层层数
        n_cls: int = 1, #分类任务的类别数量
        vocab: Any = None, #词汇表对象,用于映射单词到整数索引
        dropout: float = 0.5, #防止过拟合的神经元丢弃率,这里设的挺高的
        pad_token: str = "<pad>", #padding字符,默认为"<pad>"
        pad_value: int = 0, #表示填充符号在词汇表中的索引值。
        do_mvc: bool = False, #否开启某个特定的多视图融合(Multiview Consensus)操作
        do_dab: bool = False, #是否使用数据增强或特征变换技术
        use_batch_labels: bool = False,  #是否在模型中利用批次标签信息
        num_batch_labels: Optional[int] = None, #指定批次标签的数量。
        domain_spec_batchnorm: Union[bool, str] = False, #是否使用特定领域的批量归一化,或者某种特殊的批量归一化方式
        input_emb_style: str = "continuous", #输入嵌入层的风格或特性
        n_input_bins: Optional[int] = None, #离散化输入特征的区间数量
        cell_emb_style: str = "cls", #细胞嵌入层的设计风格
        mvc_decoder_style: str = "inner product", #多视图融合解码器的设计方式
        ecs_threshold: float = 0.3,
        explicit_zero_prob: bool = False,
        use_fast_transformer: bool = False, #示是否使用快速Transformer结构来加速计算
        fast_transformer_backend: str = "flash", #指定了快速Transformer后端实现的方式
        pre_norm: bool = False, #快速Transformer后端实现的方式
    ):

首先作者在init方法中预先设定了一大串default,每个参数的具体意义作者已经在代码中用注释标注出来了,这里值得关注的参数有几个:

        super().__init__()
        self.model_type = "Transformer"
        self.d_model = d_model
        self.do_dab = do_dab
        self.ecs_threshold = ecs_threshold
        self.use_batch_labels = use_batch_labels
        self.domain_spec_batchnorm = domain_spec_batchnorm
        self.input_emb_style = input_emb_style
        self.cell_emb_style = cell_emb_style
        self.explicit_zero_prob = explicit_zero_prob
        self.norm_scheme = "pre" if pre_norm else "post"
        if self.input_emb_style not in ["category", "continuous", "scaling"]:
            raise ValueError(
                f"input_emb_style should be one of category, continuous, scaling, "
                f"got {input_emb_style}"
            )
        if cell_emb_style not in ["cls", "avg-pool", "w-pool"]:
            raise ValueError(f"Unknown cell_emb_style: {cell_emb_style}")
        if use_fast_transformer:
            if not flash_attn_available:
                warnings.warn(
                    "flash-attn is not installed, using pytorch transformer instead. "
                    "Set use_fast_transformer=False to avoid this warning. "
                    "Installing flash-attn is highly recommended."
                )
                use_fast_transformer = False
        self.use_fast_transformer = use_fast_transformer

接下来是将参数传入类,

        # TODO: add dropout in the GeneEncoder
        self.encoder = GeneEncoder(ntoken, d_model, padding_idx=vocab[pad_token])

这里作者展示了一个GeneEncoder方法,与作者在文章中提到的将基因转化为字向量的操作有关,让我们来看看这个方法吧:

class GeneEncoder(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)  # (batch, seq_len, embsize)
        x = self.enc_norm(x)
        return x

这里面包括3个参数输入的embedding数量,embedding的维度,和padding。 其中

  • num_embeddings参数表示词汇表大小,即基因词汇表中不同的基因标记数量。

  • embedding_dim参数表示嵌入空间的维度,也就是每个基因标记将被映射到一个多维向量的空间,其维度就是这个参数的值。

  • padding_idx参数是一个可选的整数,默认为None。如果设置了这个值,它将在嵌入矩阵中对应的位置填充为全零向量,通常用于在序列填充时保持不变。

作者随后初始化了一个nn.Embedding(PyTorch中用于创建词嵌入的层),它将会创建一个权重矩阵,其中每一行对应词汇表中一个基因标记的嵌入向量。并且对nn.Embedding 进行nn.LayerNorm归一化层(有助于模型收敛和平稳训练过程)。
可以发现这和我们的transformer中encoder进行multihead之后LayerNorm非常相似。

接下来,作者写了一个value Encoder类(通常用于将原始数值或类别特征编码为具有固定维度的向量表示。在处理表格数据或非结构化数据时,可以帮助模型理解并捕获特征的潜在语义信息。)

        # Value Encoder, NOTE: the scaling style is also handled in _encode method
        if input_emb_style == "continuous":
            self.value_encoder = ContinuousValueEncoder(d_model, dropout)
        elif input_emb_style == "category":
            assert n_input_bins > 0
            self.value_encoder = CategoryValueEncoder(
                n_input_bins, d_model, padding_idx=pad_value
            )
        else:
            self.value_encoder = nn.Identity()  # nn.Softmax(dim=1)
            # TODO: consider row-wise normalization or softmax
            # TODO: Correct handle the mask_value when using scaling

这里的value_encoder的实例化根据不同类型的输入特征(由input_emb_style参数决定)进行:

  • 对于连续数值特征,使用ContinuousValueEncoder进行编码**,来捕捉连续值之间的复杂关系。

  • 对于类别特征,使用CategoryValueEncoder进行编码,常见的做法是使用嵌入层(Embedding layer),将每个类别映射到一个低维稠密向量,同时设置padding_idx用于处理填充值。

  • 如果输入特征风格不符合上述两种情况,暂时使用nn.Identity作为默认编码器,这意味着不对输入数据做任何变换,直接传递原数据。

接下来作者还定义了批处理标签编码器(Batch Label Encoder)使用与否。这里我们不主要讨论。随后,作者根据配置参数选择和实例化一个编码器(TransformerEncoder)。

        if use_fast_transformer:
            if fast_transformer_backend == "linear":
                self.transformer_encoder = FastTransformerEncoderWrapper(
                    d_model, nhead, d_hid, nlayers, dropout
                )
            elif fast_transformer_backend == "flash":
                encoder_layers = FlashTransformerEncoderLayer(
                    d_model,
                    nhead,
                    d_hid,
                    dropout,
                    batch_first=True,
                    norm_scheme=self.norm_scheme,
                )
                self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        else:
            encoder_layers = TransformerEncoderLayer(
                d_model, nhead, d_hid, dropout, batch_first=True
            )
            self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

这里我们可以发现,如果 use_fast_transformer 设置为True,表示模型将使用快速Transformer变体。

  • 根据 fast_transformer_backend 参数判断,包装了一个基于线性注意力机制的快速Transformer编码器。还是另一种基于flash的快速Transformer的实现,同样接收类似的参数。最后将多个 FlashTransformerEncoderLayer 层封装到 TransformerEncoder 中,通过传入 encoder_layers 和 nlayers 参数确定封装的层数。

  • 如果使用传统的Transformer编码器。则实例化一个 TransformerEncoderLayer 类,传入包括模型的隐藏层维度 d_model、多头注意力的头数 nhead、隐藏层大小 d_hid、dropout概率 dropout,并设置 batch_first 为True,表示输入数据的维度顺序为 (batch_size, sequence_length, d_model)。 同样地,多层封装后,构建完整的编码器。

随后作者初始化大量decoder适应于不同后续微调任务,包括ExprDecoder(矩阵解码器),ClsDecoder(分类解码器),MVCDecoder(多模态解码器,用于多组学整合),AdversarialDiscriminator(对抗性判别器),Similarity(计算两个向量或两个特征向量集合相似度的类) ,CrossEntropyLoss最后调用模型的init_weights()方法,用来初始化模型中的权重参数。

        self.decoder = ExprDecoder(
            d_model,
            explicit_zero_prob=explicit_zero_prob,
            use_batch_labels=use_batch_labels,
        )
        self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers=nlayers_cls)
        if do_mvc:
            self.mvc_decoder = MVCDecoder(
                d_model,
                arch_style=mvc_decoder_style,
                explicit_zero_prob=explicit_zero_prob,
                use_batch_labels=use_batch_labels,
            )

        if do_dab:
            self.grad_reverse_discriminator = AdversarialDiscriminator(
                d_model,
                n_cls=num_batch_labels,
                reverse_grad=True,
            )

        self.sim = Similarity(temp=0.5)  # TODO: auto set temp
        self.creterion_cce = nn.CrossEntropyLoss()

        self.init_weights()

由于篇幅限制,剩下的代码以及大模型操作,我们在下一期再介绍吧!

  • 23
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
你好!对于 Faster R-CNN 的代码精读,你可以参考以下步骤: 1. 阅读主要的文件结构:Faster R-CNN 通常包括几个主要的文件,包括模型定义文件、数据加载文件、训练和测试脚本等。首先,了解代码的整体结构和文件之间的关系是很重要的。 2. 理解模型架构:查看模型定义文件,通常是一个包含网络结构的类或函数。在这个文件中,你可以找到网络的主要组件,如卷积层、池化层、全连接层等。仔细阅读这些组件的定义和参数设置,对整个网络的结构和运作方式有一个清晰的理解。 3. 研究损失函数:Faster R-CNN 使用一种特定的损失函数来衡量模型测与真实标签之间的差异。阅读训练脚本中的损失函数实现部分,了解如何计算损失以及如何反向传播梯度更新模型参数。 4. 数据加载与处理:Faster R-CNN 在训练和测试过程中需要加载和处理数据。查看数据加载文件,了解如何从数据集中读取图像和标签,并进行处理操作,如缩放、裁剪、归一化等。 5. 推断与测过程:Faster R-CNN 的目标是在图像中检测和定位物体。了解测试脚本中的推断和测过程,包括如何对输入图像进行前向传播,并根据测结果生成检测框和类别。 6. 调试和修改:在阅读代码的过程中,你可能会遇到一些问题或有一些想法来改进模型。尝试调试代码并进行一些修改,看看是否能够改善模型的性能或加入新的功能。 请记住,Faster R-CNN 是一个相对复杂的模型,可能需要花费一些时间来理解和熟悉代码。阅读官方的文档和参考资料,以及查找其他人的实现和解释,都是学习和理解代码的有用资源。祝你成功!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值