# How to split an AnnData object into k folds (如何将 AnnData 拆分为 k 份)
import anndata as ad
import numpy as np
import scanpy as sc
class KSplitAnndata:
    """Split an AnnData object into ``k`` disjoint folds of near-equal size.

    Every observation is assigned to exactly one fold; no cells are dropped.
    """

    @staticmethod
    def _split_positions(groups: list, k: int, random_state: int = 0) -> list:
        """Shuffle each group of row positions and deal it into ``k`` folds.

        :param groups: list of 1-D integer arrays of row positions; each group
            is split independently (used for per-batch stratification).
        :param k: number of folds; must be >= 1.
        :param random_state: seed for ``numpy.random.default_rng``.
        :return: list of ``k`` sorted integer position arrays.
        :raises ValueError: if ``k`` < 1.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        rng = np.random.default_rng(random_state)
        # Seed every fold with an empty array so concatenation also works
        # when there are no groups at all (e.g. an empty AnnData).
        folds = [[np.empty(0, dtype=np.intp)] for _ in range(k)]
        offset = 0
        for positions in groups:
            # array_split gives sizes differing by at most one and never
            # drops the remainder (unlike taking n // k per fold).
            chunks = np.array_split(rng.permutation(positions), k)
            for j, chunk in enumerate(chunks):
                folds[(j + offset) % k].append(chunk)
            # Rotate the starting fold so the "+1" remainders of successive
            # groups land on different folds instead of always the first ones.
            offset = (offset + len(positions)) % k
        # Sorting keeps the cells of each fold in their original order.
        return [np.sort(np.concatenate(parts)) for parts in folds]

    @staticmethod
    def _base_split(data: ad.AnnData, k: int, random_state: int = 0) -> list:
        """Randomly split ``data`` into ``k`` folds, ignoring any batch structure.

        :param data: AnnData object to split (not modified).
        :param k: number of folds; must be >= 1.
        :param random_state: seed for the shuffle, for reproducible folds.
        :return: list of ``k`` independent AnnData copies.
        :raises ValueError: if ``k`` < 1.
        """
        fold_positions = KSplitAnndata._split_positions(
            [np.arange(data.n_obs)], k, random_state
        )
        return [data[positions].copy() for positions in fold_positions]

    @staticmethod
    def k_split(
        data: ad.AnnData, k: int, batch_key: str = None, random_state: int = 0
    ) -> list:
        """Split ``data`` into ``k`` disjoint folds of near-equal size.

        :param data: AnnData object to split (not modified).
        :param k: number of folds; must be >= 1.
        :param batch_key: optional ``obs`` column. When given, each batch is
            split separately so every fold gets a near-equal share of every
            batch; cells with a missing batch label form their own group.
            Default is None (plain random split).
        :param random_state: seed for the shuffle, for reproducible folds.
        :return: list of ``k`` independent AnnData copies; every cell of
            ``data`` appears in exactly one of them.
        :raises ValueError: if ``k`` < 1.
        """
        if not batch_key:
            return KSplitAnndata._base_split(data, k, random_state)

        labels = data.obs[batch_key]
        # unique() preserves order of first appearance, so the result is
        # reproducible (iterating a set of strings is not, across runs).
        groups = [
            np.flatnonzero((labels == batch).to_numpy())
            for batch in labels.dropna().unique()
        ]
        # NaN never compares equal to itself, so unlabeled cells need their
        # own group or they would silently vanish from every fold.
        missing = np.flatnonzero(labels.isna().to_numpy())
        if missing.size:
            groups.append(missing)

        fold_positions = KSplitAnndata._split_positions(groups, k, random_state)
        return [data[positions].copy() for positions in fold_positions]
# Example: split a small synthetic dataset into 5 batch-stratified folds.
# Guarded so importing this module does not execute it (the original
# referenced an undefined ``adata`` at import time).
if __name__ == "__main__":
    example_rng = np.random.default_rng(0)
    adata = ad.AnnData(
        example_rng.poisson(1.0, size=(103, 20)).astype(np.float32)
    )
    adata.obs["batch"] = np.repeat(["a", "b", "c", "d"], [25, 26, 27, 25])
    adata_list = KSplitAnndata.k_split(adata, 5, "batch")
    # Every cell is used exactly once; fold sizes differ by at most one.
    print([fold.n_obs for fold in adata_list])