from fastNLP.io import SST2Pipe
from fastNLP import DataSetIter
from torchsampler import ImbalancedDatasetSampler
# Load SST-2 through fastNLP's default pipeline and inspect what came back.
pipe = SST2Pipe()
databundle = pipe.process_from_file()
vocab = databundle.vocabs['words']
print(databundle)
print(databundle.datasets['train'][0])
print(vocab)

# Split the official training set: per fastNLP's DataSet.split(ratio), the
# second returned DataSet holds the `ratio` fraction (1.5% here), which is
# used as a held-out test set; the first keeps the remaining 98.5%.
train_data, test_data = databundle.get_dataset('train').split(0.015)
dev_data = databundle.get_dataset('dev')
print(len(train_data), len(dev_data), len(test_data))

# A tiny slice (first 10 dev instances) for the sampler demo below.
tmp_data = dev_data[:10]
def callback_get_label(dataset, idx):
    """Return the ``'target'`` field of instance ``idx`` in ``dataset``.

    Used by the imbalanced sampler to learn each sample's class.
    """
    return dataset[idx]['target']
# The sampler needs a callback_get_label function to find out the label of
# every sample in the dataset.
sampler = ImbalancedDatasetSampler(
    tmp_data,
    callback_get_label=callback_get_label,
)
batch = DataSetIter(dataset=tmp_data, batch_size=3, sampler=sampler)
for batch_x, batch_y in batch:
    # Dump the input fields and the target fields of each mini-batch.
    print("batch_x: ", batch_x)
    print("batch_y: ", batch_y)
源码地址:https://github.com/ufoym/imbalanced-dataset-sampler 。其中 mnist.ipynb 的以下两行在本机执行时报错:

`imbalanced_train_dataset.train_labels = np.delete(train_loader.dataset.train_labels, idx_to_del, axis=0)`
`imbalanced_train_dataset.train_data = np.delete(train_loader.dataset.train_data, idx_to_del, axis=0)`

报错信息表明 imbalanced_train_dataset 的这两个属性不能直接赋值(较新版本 torchvision 的 MNIST 中,`train_labels`/`train_data` 是没有 setter 的只读别名,实际数据保存在 `targets`/`data` 中)。以下为修改后的构建过程:
class IMB(torchvision.datasets.MNIST):
    """MNIST training set with the samples in ``idx_to_del`` removed.

    Workaround for the notebook's original approach (copy the dataset, then
    assign to ``train_labels`` / ``train_data``), which fails because those
    attributes cannot be assigned. Instead the reduced images and labels are
    stored in ``data`` / ``targets``, which ``MNIST.__getitem__`` reads.

    ``MNIST.__init__`` is deliberately not called: it would load the MNIST
    files from disk again.

    Args:
        transform: applied to each image by ``MNIST.__getitem__``.
        target_transform: applied to each label by ``MNIST.__getitem__``.
        source_dataset: MNIST dataset to subsample. Defaults to the
            notebook-level ``train_loader.dataset`` (original behaviour).
        indices_to_delete: indices of the samples to drop. Defaults to the
            notebook-level ``idx_to_del`` (original behaviour).
    """

    def __init__(self, transform=None, target_transform=None,
                 source_dataset=None, indices_to_delete=None):
        import numpy as np
        import torch

        # NOTE(review): these defaults are globals defined outside this
        # snippet (presumably in mnist.ipynb); pass them explicitly instead.
        if source_dataset is None:
            source_dataset = train_loader.dataset
        if indices_to_delete is None:
            indices_to_delete = idx_to_del

        # ``targets`` / ``data`` are the current names behind the deprecated
        # ``train_labels`` / ``train_data`` aliases, so reading them directly
        # avoids deprecation warnings. ``torch.as_tensor`` makes sure the
        # results are tensors again, because ``MNIST.__getitem__`` calls
        # ``.numpy()`` on each image.
        self.targets = torch.as_tensor(
            np.delete(source_dataset.targets, indices_to_delete, axis=0))
        self.data = torch.as_tensor(
            np.delete(source_dataset.data, indices_to_delete, axis=0))
        self.transform = transform
        self.target_transform = target_transform
        # Attributes MNIST.__init__ would normally set; repr() reads them.
        self.root = getattr(source_dataset, 'root', None)
        self.train = getattr(source_dataset, 'train', True)
# Build the imbalanced MNIST subset and a plain shuffling loader over it
# (no re-balancing sampler here). `train_transform`, `args` and `kwargs` are
# defined outside this snippet (presumably in mnist.ipynb).
imbalanced_train_dataset = IMB(transform=train_transform)
imbalanced_train_loader = torch.utils.data.DataLoader(
    imbalanced_train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
Imbalanced Dataset Sampler 有 PyTorch 和 MXNet 两个版本(见下方 `__iter__` 的两种实现)。目前我自己写的 MXNet 版本效率极低,推荐使用 PyTorch 版本;选用哪个版本都不影响该 sampler 在 PyTorch 或 MXNet 中的使用。
def callback_get_label(dataset, idx):
    """Return element ``1`` of ``dataset[idx]`` as the sample's label.

    Suits datasets whose items are ``(input, label)`` pairs, e.g. the
    ``(image, target)`` tuples of torchvision datasets.
    """
    sample = dataset[idx]
    return sample[1]
import torch
class ImbalancedDatasetSampler(Sampler):
    """Randomly sample dataset indices so that every class is drawn about equally often.

    Each sample gets weight ``1 / (number of samples sharing its label)``, and
    indices are drawn with replacement in proportion to those weights.

    Arguments:
        dataset: indexable dataset; only ``len(dataset)`` and the label
            callback touch it.
        indices (list, optional): indices of ``dataset`` to sample from;
            defaults to every index in ``range(len(dataset))``.
        num_samples (int, optional): number of indices yielded per
            iteration; defaults to ``len(indices)``.
        callback_get_label (callable): ``callback_get_label(dataset, idx)``
            returns the (hashable) label of sample ``idx``. Required unless
            there are no indices.
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
        """Compute one sampling weight per index.

        Sampling idea: count the samples of each class, give every sample the
        inverse count of its class as weight (so the weight list is as long as
        the index list), then draw indices from that multinomial distribution.
        """
        from collections import Counter

        # If indices is not provided, every element of the dataset is used.
        self.indices = list(range(len(dataset))) if indices is None else indices
        self.callback_get_label = callback_get_label
        # If num_samples is not provided, draw len(indices) samples per pass.
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # Fetch each label exactly once: the callback may index the dataset,
        # which can be expensive (e.g. loading and transforming an image).
        labels = [self._get_label(dataset, idx) for idx in self.indices]
        label_to_count = Counter(labels)
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        """Return the label of sample ``idx`` via the user-supplied callback.

        Raises:
            NotImplementedError: if no ``callback_get_label`` was given.
        """
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        raise NotImplementedError

    # PyTorch-native alternative to the numpy-based __iter__ below:
    # def __iter__(self):
    #     return (self.indices[i] for i in torch.multinomial(
    #         self.weights, self.num_samples, replacement=True))

    def __iter__(self):
        """Yield ``self.num_samples`` indices drawn according to ``self.weights``."""
        return (self.indices[i] for i in self.mxmulti(self.weights, self.num_samples))

    @classmethod
    def mxmulti(cls, weights, num_samples=None):
        """Draw positions ``0 .. len(weights) - 1`` with replacement, proportional to ``weights``.

        A single multinomial draw gives how often each position is picked.
        Each position is then repeated that many times and the result is
        shuffled, which has the same distribution as ``num_samples``
        independent draws.

        Args:
            weights: non-negative weights (list, numpy array or tensor).
            num_samples (int, optional): number of draws; defaults to
                ``len(weights)``.

        Returns:
            list[int]: the drawn positions, in random order.
        """
        import numpy as np

        probs = np.array(weights).astype(float)
        if num_samples is None:
            num_samples = len(probs)
        # Nothing to draw from (or nothing requested): avoid a 0/0 normalization.
        if probs.size == 0 or num_samples <= 0:
            return []
        probs = probs / probs.sum()
        counts = np.random.multinomial(num_samples, probs)
        # Use numpy's RNG throughout so np.random.seed controls the whole draw.
        drawn = np.repeat(np.arange(len(probs)), counts)
        return np.random.permutation(drawn).tolist()

    def __len__(self):
        """Number of indices yielded per iteration."""
        return self.num_samples