不均衡样本的sampler构建 Imbalanced Dataset Sampler

from fastNLP.io import SST2Pipe
from fastNLP import DataSetIter
from torchsampler import ImbalancedDatasetSampler
# Load and preprocess the SST-2 dataset with fastNLP's pipe; called without a
# path, so the pipe's default data location is used.
pipe = SST2Pipe()
databundle = pipe.process_from_file()
# Vocabulary built for the 'words' field.
vocab = databundle.vocabs['words']
# Inspect the bundle, the first training instance, and the word vocabulary.
print(databundle)
print(databundle.datasets['train'][0])
print(databundle.vocabs['words'])

train_data = databundle.get_dataset('train')
# Carve a held-out test set out of the training data.
# NOTE(review): presumably 0.015 is the fraction given to the second returned
# dataset (test_data) -- confirm against the fastNLP DataSet.split docs.
train_data, test_data = train_data.split(0.015)
dev_data = databundle.get_dataset('dev')
print(len(train_data),len(dev_data),len(test_data))
# Small slice (first 10 dev instances) used below to demo the sampler.
tmp_data = dev_data[:10]

def callback_get_label(dataset, idx):
    """Return the label of sample ``idx``, read from its ``'target'`` field."""
    return dataset[idx]['target']

# A callback_get_label function must be defined for the dataset so that the
# sampler can determine the label of every sample.
sampler=ImbalancedDatasetSampler(tmp_data,callback_get_label=callback_get_label)
# Iterate over tmp_data in batches of 3, with sample order decided by the sampler.
batch = DataSetIter(batch_size=3, dataset=tmp_data,
                    sampler=sampler)
for batch_x, batch_y in batch:
    print("batch_x: ", batch_x)
    print("batch_y: ", batch_y)

源码地址:https://github.com/ufoym/imbalanced-dataset-sampler 。其中 mnist.ipynb 的以下两行在本机执行时报错:`imbalanced_train_dataset.train_labels = np.delete(train_loader.dataset.train_labels, idx_to_del, axis=0)` 和 `imbalanced_train_dataset.train_data = np.delete(train_loader.dataset.train_data, idx_to_del, axis=0)`。报错信息显示 imbalanced_train_dataset 的属性不能直接修改,以下为修改后的构建过程:

class IMB(torchvision.datasets.MNIST):
    """MNIST training split with the samples listed in ``idx_to_del`` removed.

    Used to build an artificially imbalanced training set.

    Arguments:
        transform (callable, optional): transform applied to each image.
        target_transform (callable, optional): transform applied to each label.
    """

    def __init__(self, transform=None, target_transform=None):
        # Let the base class load the training split and set up its own state
        # (root, transforms, data, targets) instead of building a throwaway
        # dataset and leaving the parent uninitialised.
        super().__init__('.', train=True, download=True,
                         transform=transform,
                         target_transform=target_transform)
        # `idx_to_del` is a module-level list of sample indices to drop.
        # NOTE(review): assumes those indices refer to the same MNIST training
        # split loaded above (the original read it via train_loader.dataset)
        # -- confirm against the notebook.
        self.data = np.delete(self.data, idx_to_del, axis=0)
        self.targets = np.delete(self.targets, idx_to_del, axis=0)

# Build the imbalanced dataset with the same transform used for training.
imbalanced_train_dataset = IMB(transform=train_transform)

# Plain shuffled loader over the imbalanced data (no re-balancing sampler).
# NOTE(review): the original line was cut off at `**kwar`; `kwargs` is assumed
# to be the DataLoader keyword dict defined earlier in the notebook -- confirm.
imbalanced_train_loader = torch.utils.data.DataLoader(
    imbalanced_train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

以下是 PyTorch 版本和 MXNet 版本的 Imbalanced Dataset Sampler。目前自己写的 MXNet 版本效率极低,推荐使用 PyTorch 版本;sampler 只返回样本索引,因此无论选用哪个版本,都不影响它在 PyTorch 或 MXNet 中使用。

def callback_get_label(dataset, idx):
    """Return the label of sample ``idx``, stored at position 1 of the sample."""
    return dataset[idx][1]
from collections import Counter

import numpy as np
import torch
from torch.utils.data import Sampler
class ImbalancedDatasetSampler(Sampler):
    """Sample dataset indices so that every class is equally likely to be drawn.

    Each sample is weighted by the inverse of its class frequency; indices are
    then drawn with replacement from the resulting multinomial distribution.

    Arguments:
        dataset: the dataset; only ``len(dataset)`` is used directly, labels
            are read through ``callback_get_label``.
        indices (list, optional): a list of indices to sample from; defaults
            to every index in ``dataset``.
        num_samples (int, optional): number of samples to draw per iteration;
            defaults to ``len(indices)``.
        callback_get_label (callable): function taking ``(dataset, index)``
            and returning that sample's (hashable) label.

    Raises:
        NotImplementedError: if ``callback_get_label`` is not provided.
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
        # If indices is not provided, all elements in the dataset are considered.
        self.indices = list(range(len(dataset))) if indices is None else indices

        # Custom label accessor.
        self.callback_get_label = callback_get_label

        # If num_samples is not provided, draw len(indices) samples per iteration.
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # Look each label up once, then count the class distribution.
        labels = [self._get_label(dataset, idx) for idx in self.indices]
        label_to_count = Counter(labels)

        # Per-sample weight = 1 / (size of the sample's class), so every class
        # carries the same total weight.
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        """Return the label of ``dataset[idx]`` via the user-supplied callback."""
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        raise NotImplementedError(
            "callback_get_label must be provided to ImbalancedDatasetSampler")

    def __iter__(self):
        # NumPy-based sampling so the sampler does not depend on torch's RNG.
        # The pure-torch equivalent is:
        #   torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[i]
                for i in self.mxmulti(self.weights, self.num_samples))

    @classmethod
    def mxmulti(cls, weights, num_samples=None):
        """Draw positions with replacement, proportionally to ``weights``.

        Arguments:
            weights: non-negative weights (need not be normalised).
            num_samples (int, optional): how many positions to draw; defaults
                to ``len(weights)``.

        Returns:
            list of int positions into ``weights``, in random order.
        """
        probs = np.array(weights, dtype=float)
        if probs.size == 0:
            return []
        if num_samples is None:
            num_samples = len(probs)
        probs = probs / probs.sum()
        # Independent draws with replacement; same distribution as drawing
        # multinomial counts and shuffling the expanded index list.
        drawn = np.random.choice(len(probs), size=num_samples, replace=True, p=probs)
        return drawn.tolist()

    def __len__(self):
        return self.num_samples

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值