今天我想实现这么一个抽样的dataloader:例如我的数据有10类,每次训练的batch中各个类别的样本数相等。例如batch_size=120,那么10类中每个类别随机抽12个样本。实现如下
%matplotlib inline
import logging
# Viz
import matplotlib.pyplot as plt
import numpy as np
# Data manipulation
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import train_test_split
import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as logging_presets
from torch.utils.data import DataLoader
# Main
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from collections import Counter
# Logs
# Raise the root logger to INFO so library progress messages are visible.
logging.getLogger().setLevel(logging.INFO)
logging.info("VERSION %s" % pytorch_metric_learning.__version__)
# get data and format
# Paul et al. 2015 hematopoiesis single-cell dataset bundled with scanpy.
adata = sc.datasets.paul15()
# can try preprocessing here...
# sc.pp.recipe_zheng17(adata)
# create dictionary of label map
# Maps integer category codes -> human-readable cluster names.
label_map = dict(enumerate(adata.obs["paul15_clusters"].cat.categories))
print(label_map)
# extract of some of the most representative clusters for training/testing
# Keep only 10 of the clusters, selected by their categorical codes.
clusters = [0, 1, 2, 3, 4, 5, 13, 14, 15, 16]
indices = adata.obs["paul15_clusters"].cat.codes.isin(clusters)
data, labels = adata.X[indices], adata.obs[indices]["paul15_clusters"].cat.codes.values
print(Counter(labels))
# Stratified 80/20 split preserves the per-class proportions in both sets.
X_train, X_val, y_train, y_val = train_test_split(
    data, labels, stratify=labels, test_size=0.2, random_state=77
)
# Sanity checks: expected per-class counts for the two largest classes
# after taking 80% of them.
print(373*0.8)
print(329*0.8)
print(Counter(y_train))
# This will be used to create train and val sets
class BasicDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset pairing a float32 feature tensor with labels.

    Features are converted from numpy to a float32 tensor once at
    construction time; labels are stored as given and indexed in parallel.
    """

    def __init__(self, data, labels):
        # Convert once up front so __getitem__ only has to index.
        self.data = torch.from_numpy(data).float()
        self.labels = labels

    def __len__(self):
        # Number of rows in the feature tensor.
        return self.data.shape[0]

    def __getitem__(self, index):
        # (feature_row, label) pair for the requested position.
        return self.data[index], self.labels[index]
# Training, validation, holdout set
train_dataset = BasicDataset(X_train, y_train)
# Set the dataloader sampler
sampler = samplers.MPerClassSampler(y_train.flatten(), m=6)
train_loader = DataLoader(train_dataset, batch_size=120, sampler=sampler)
print("*"*50)
for batch_idx, (train_data, training_labels) in enumerate(train_loader):
break
测试结果如下
但是这个地方还是有一些问题的,需要修改一下,就是这个dataloader的长度问题
可以看到这里的结果是这样的,很明显数据只有1472条,但是dataloader却有833个batch,显然是不对的。这里的问题在于MPerClassSampler的length_before_new_iter参数默认值是100000,100000/120≈833,所以才会出现833个batch
所以控制一个dataloader里到底有多少个batch的是length_before_new_iter这个参数,应该改成length_before_new_iter=len(train_dataset)
所以正确的代码应该是
%matplotlib inline
import logging
# Viz
import matplotlib.pyplot as plt
import numpy as np
# Data manipulation
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.manifold import TSNE
from sklearn.metrics import pairwise_distances
from sklearn.model_selection import train_test_split
import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as logging_presets
from torch.utils.data import DataLoader
# Main
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from collections import Counter
# Logs
# (Same setup as the first snippet, repeated so this section runs on its own.)
logging.getLogger().setLevel(logging.INFO)
logging.info("VERSION %s" % pytorch_metric_learning.__version__)
# get data and format
# Paul et al. 2015 hematopoiesis single-cell dataset bundled with scanpy.
adata = sc.datasets.paul15()
# can try preprocessing here...
# sc.pp.recipe_zheng17(adata)
# create dictionary of label map
# Maps integer category codes -> human-readable cluster names.
label_map = dict(enumerate(adata.obs["paul15_clusters"].cat.categories))
print(label_map)
# extract of some of the most representative clusters for training/testing
clusters = [0, 1, 2, 3, 4, 5, 13, 14, 15, 16]
indices = adata.obs["paul15_clusters"].cat.codes.isin(clusters)
data, labels = adata.X[indices], adata.obs[indices]["paul15_clusters"].cat.codes.values
print(Counter(labels))
# Stratified 80/20 split preserves the per-class proportions in both sets.
X_train, X_val, y_train, y_val = train_test_split(
    data, labels, stratify=labels, test_size=0.2, random_state=77
)
# Sanity checks on expected per-class counts after an 80% split.
print(373*0.8)
print(329*0.8)
print(Counter(y_train))
# This will be used to create train and val sets
class BasicDataset(torch.utils.data.Dataset):
    """Map-style dataset: float32 feature tensor plus a parallel label array."""

    def __init__(self, data, labels):
        # One-time numpy -> float32 tensor conversion.
        self.data = torch.from_numpy(data).float()
        self.labels = labels

    def __getitem__(self, index):
        # (feature_row, label) pair for the requested position.
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)
# Training, validation, holdout set
train_dataset = BasicDataset(X_train, y_train)
print(X_train.shape)
# Set the dataloader sampler
# length_before_new_iter caps the sampler's iterator length at the dataset
# size, so len(train_loader) now reflects the real amount of data.
# FIX: when MPerClassSampler is constructed with a batch_size, the
# DataLoader MUST use the SAME batch_size -- otherwise the DataLoader
# re-chunks the sampler's carefully grouped index stream (here 60 with 128)
# and the m-samples-per-class guarantee inside each batch is lost.
BATCH_SIZE = 60  # must satisfy BATCH_SIZE % m == 0
sampler = samplers.MPerClassSampler(
    y_train.flatten(),
    m=6,
    batch_size=BATCH_SIZE,
    length_before_new_iter=len(train_dataset),
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
print(len(train_loader))
print("*"*50)
# Count batches in one full pass and compare against dataset_size / batch_size.
count = 0
for batch_idx, (train_data, training_labels) in enumerate(train_loader):
    count = count + 1
print(count)
print(X_train.shape[0] / BATCH_SIZE)
结果如下
注意:上面train_loader的batch_size(128)和sampler的batch_size(60)不一致,这个设置是有问题的,两者必须保持相等
from collections import Counter
import numpy as np
from torch.utils.data import DataLoader
import torch
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
# This will be used to create train and val sets
class BasicDataset(torch.utils.data.Dataset):
    """Map-style dataset: float32 feature tensor plus a parallel label array."""

    def __init__(self, data, labels):
        # One-time numpy -> float32 tensor conversion.
        self.data = torch.from_numpy(data).float()
        self.labels = labels

    def __getitem__(self, index):
        # (feature_row, label) pair for the requested position.
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)
# This will be used to create train and val set
# Build a synthetic, heavily imbalanced 12-class label vector
# (class sizes from 67 down to 5) to probe the sampler's behavior.
class_id =[i for i in range(12)]
num_sub_class =np.array([67, 58, 55, 45, 44, 41, 39, 39, 34, 26, 22, 5])
label_list =[]
for i in range(12):
    label_list =label_list + [i]*num_sub_class[i]
y_train =np.array(label_list)
# Labels double as the "features" here -- only the sampling pattern matters.
train_dataset = BasicDataset(y_train, y_train)
# Set the dataloader sampler
# NOTE(review): m=30 with batch_size=128 means at most 128//30 = 4 classes
# fit in a batch, so most classes are absent from any given batch --
# this snippet exists to demonstrate exactly that problem.
sampler = samplers.MPerClassSampler(y_train.flatten(), m=30,length_before_new_iter=len(train_dataset))
train_loader = DataLoader(train_dataset, batch_size=128, sampler=sampler)
print(len(train_loader))
print("*"*50)
count=0
for batch_idx, (train_data, training_labels) in enumerate(train_loader):
    count=count+1
    # Per-batch class histogram shows which classes were actually drawn.
    print(Counter(training_labels.data.numpy()))
    print(train_data.shape)
print(count)
结果如下
可以看到这里设置m=30时,会导致一些类别其实并没有抽样,这个需要注意
from collections import Counter
import numpy as np
from torch.utils.data import DataLoader
import torch
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
# This will be used to create train and val sets
class BasicDataset(torch.utils.data.Dataset):
    """Map-style dataset: float32 feature tensor plus a parallel label array."""

    def __init__(self, data, labels):
        # One-time numpy -> float32 tensor conversion.
        self.data = torch.from_numpy(data).float()
        self.labels = labels

    def __getitem__(self, index):
        # (feature_row, label) pair for the requested position.
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)
# This will be used to create train and val set
class_id =[i for i in range(12)] # 12 classes
num_sub_class =np.array([67, 58, 55, 45, 44, 41, 39, 39, 34, 26, 22, 5])
label_list =[]
for i in range(12):
    label_list =label_list + [i]*num_sub_class[i]
y_train =np.array(label_list)
# Labels double as the "features" -- only the sampling pattern matters.
train_dataset = BasicDataset(y_train, y_train)
# Set the dataloader sampler
# NOTE(review): m=10 x 12 classes = 120 indices per sampler group, but the
# DataLoader batch_size is 128, so batch boundaries drift across groups
# and per-class counts within a batch are still uneven (fixed below by
# using batch_size=120).
sampler = samplers.MPerClassSampler(y_train.flatten(), m=10,length_before_new_iter=len(train_dataset))
train_loader = DataLoader(train_dataset, batch_size=128, sampler=sampler)
print(len(train_loader))
print("*"*50)
count=0
# Two passes to show the sampler produces a fresh shuffle each epoch.
for _ in range(2):
    for batch_idx, (train_data, training_labels) in enumerate(train_loader):
        count=count+1
        print(Counter(training_labels.data.numpy()))
        print(train_data.shape)
print(count)
如果要保证每次抽样时各个类别的样本数完全一样,就要保证 m × 类别数 = train_loader的batch_size(这里 10 × 12 = 120)
from collections import Counter
import numpy as np
from torch.utils.data import DataLoader
import torch
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
# This will be used to create train and val sets
class BasicDataset(torch.utils.data.Dataset):
    """Map-style dataset: float32 feature tensor plus a parallel label array."""

    def __init__(self, data, labels):
        # One-time numpy -> float32 tensor conversion.
        self.data = torch.from_numpy(data).float()
        self.labels = labels

    def __getitem__(self, index):
        # (feature_row, label) pair for the requested position.
        return self.data[index], self.labels[index]

    def __len__(self):
        return len(self.data)
# This will be used to create train and val set
num_classes = 12
class_id = list(range(num_classes))  # the 12 class ids
num_sub_class = np.array([67, 58, 55, 45, 44, 41, 39, 39, 34, 26, 22, 5])
# Expand each class id by its per-class count into one flat label vector.
y_train = np.repeat(np.arange(num_classes), num_sub_class)
# Labels double as the "features" -- only the sampling pattern matters.
train_dataset = BasicDataset(y_train, y_train)
# Set the dataloader sampler
# batch_size (120) == m (10) * number of classes (12), so every batch
# contains exactly 10 samples from every class.
sampler = samplers.MPerClassSampler(
    y_train.flatten(), m=10, length_before_new_iter=len(train_dataset)
)
train_loader = DataLoader(train_dataset, batch_size=120, sampler=sampler)
print(len(train_loader))
print("*" * 50)
count = 0
# Two passes to show the sampler reshuffles on each new epoch.
for _ in range(2):
    for batch_idx, (train_data, training_labels) in enumerate(train_loader):
        count += 1
        print(Counter(training_labels.data.numpy()))
        print(train_data.shape)
print(count)