介绍
单细胞测序技术越来越流行,很多文章都在做单细胞测序数据的分析工作,在2022年腾讯AI实验室发表了一篇scBERT的文章,就是用深度学习模型对单细胞数据进行预训练,并进行下游任务设计,比如单细胞的类型注释任务。本文主要是记录我对这篇文章预训练和分类任务部分的理解与实现。旨在理解BERT模型框架在单细胞数据中的运用,并根据数据的生物学特征,进行模型框架的调整。
原文介绍
原文标题及链接:scBERT as a large-scale pretrained deep language model for cell type annotation of single-cell RNA-seq data,https://www.nature.com/articles/s42256-022-00534-z。原文主要是利用BERT的框架在自然语言处理上的方式,类别来处理单细胞的gene表达数据。比如,将一个单细胞样本的基因表达值作为token,一个样本就是一句话,用BERT的框架学习,gene表达值之间的相互关系,得到预训练模型,然后进行下游任务。
流程框架
此框架分为两个部分:自监督的预训练学习任务和有监督的微调任务
在自监督的预训练学习任务部分中,作者使用了100万多个没有标签的单细胞样本进行学习,这些样本的一共有超过16906个gene。另外为了模拟自然语言处理模型的方式,引入了gene与gene之间的相互关系矩阵来替代position embedding。
在有监督的微调任务,作者另外收集了5个不同study的单细胞测序数据,并且有标签,进行单细胞注释任务训练,同时观察,经过预训练模型后,是否可以消除不同study之间的批次效应。
数据处理
单细胞的gene表达矩阵是n x m的,n表示单细胞个数,m表示gene的个数, a i j a_{ij} aij 表示gene表达值,为连续型数值。
( a 11 ⋯ a 1 m ⋮ ⋱ ⋮ a n 1 ⋯ a n m ) \begin{pmatrix} a_{11} & \cdots & a_{1m} \\ \vdots & \ddots & \vdots \\ a_{n1} & \cdots & a_{nm} \end{pmatrix}
a11⋮an1⋯⋱⋯a1m⋮anm
要知道,在NLP中,字符是离散的,很容易token化。对于连续型的gene表达值,作者将其binning化,即分箱处理。比如一个细胞的所有gene表达值范围在[0, 12.3],我们取5个bin,那么该细胞的表达值范围就为[0, 1, 2, 3, 4, 5]。即对gene表达值取整,大于5的表达值都归为5。然后将这些表达值token化,即每个单细胞有6个token,分别为0,1,2,3,4,5。
接下来,根据BERT模型,我们需要随机mask token,即将某些gene的表达值mask掉,自监督任务即要学习预测这些mask掉的token,从而建立起token与token之间的关系。
接下来表达值token的embedding表示,token初始化embedding向量,比如每一个表达值token用100维向量表示。这样处理后,每一个样本的形状为 m × 100 m\times100 m×100,所有样本的形状为 n × m × 100 n\times m\times 100 n×m×100.
最后根据BERT模型,将表达值token embedding 与 gene与gene关系的embedding 对应元素相加,输入transformer模型,学习表达值token的embedding,预测mask掉的表达值token。
代码实现
原文代码使用的是performer框架,是对长文本序列处理比较高效的框架,pytorch的transformer最大允许输入长度为512。作者用这个框架,在多块V100 GPU 耗时1-2 周才训练完成,代码地址为:https://github.com/TencentAILabHealthcare/scBERT。我使用的数据也是在这个地址获取的。
改进
根据gene在细胞中表达的规律特性,我们对数据处理部分进行改进:
- 并不是每个细胞都表达16906个gene,即有很多gene的表达值为0,实际上有数据的gene,中位数在500左右。因此采用padding的方式将长短不一的非0表达值gene补齐
- 对于100多万数据,即使只有500个gene长度,普通的用户,包括我也没有足够的GPU资源可以完全运行起来。因为gene之间是没有顺序关系的,因此我们可以每次迭代随机选取一部分的gene作为输入,而且可以设置每个gene被选中的频率,这部分想法参考了scformer,https://www.biorxiv.org/content/10.1101/2022.11.20.517285v1.full的文章。这样做后可以增大batch_size,能够在一张A100 40GB的显卡上跑起来。
- 由于改进了输入长度,我们可以利用pytorch的transformer框架来实现
- 在生物学上,gene没有像自然语言有明确的顺序,因此可以不考虑position embedding的输入
实现代码
下面是我根据方法描述,使用pytorch的transformer框架实现的代码
预训练单细胞数据:https://drive.google.com/file/d/1fNZbKx6LPeoS0hbVYJFI8jlDlNctZxlU/view?usp=sharing
为了能够学习模型,将模型跑起来,我们可以选取10000个单细胞进行测试
数据处理部分
- 数据输入格式为h5ad格式
- 可以根据gene表达频率选择gene
- 使用padding的方式补齐
# preprocess.py
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import scanpy as sc
class Preprocessor(object):
def __init__(self, n_binning):
self.n_binning = n_binning
def __call__(self,