pytorch处理数据类别不均匀问题

pytorch处理数据类别不均匀问题

参考资料git地址

一:问题类型

最近在训练模型时出现了loss函数值为nan的情况,即出现了训练不收敛问题,前后又整理了一遍思路忽然记起本次的训练数据有及大的类别不均匀的问题,在没有经过处理前是直接用的乱序读入shuffle=True,对数据进行Dataset Sampler之后问题得到解决,最终训练出了准确率较高的模型。

二:什么是Dataset Sampler

在pytorch中有很多此类方法,比如WeightedRandomSampler啊等等使数据平衡的方法,简单的说也就是给每类数据赋予一定的权重,使读入的数据达到平衡

三:Imbalanced Dataset Sampler

本包的git地址在文章开头已经给出
这就是今天拿出来分享的一个git上开源的包,其源码思想很巧妙,也非常便利,非常适合拿来直接上手,但是值得注意的是,它目前所支持的数据集格式有限,我们可以goto到源码如下:

class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        indices (list, optional): a list of indices
        num_samples (int, optional): number of samples to draw
        callback_get_label func: a callback-like function which takes two arguments - dataset and index
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
                
        # if indices is not provided, 
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) \
            if indices is None else indices

        # define custom callback
        self.callback_get_label = callback_get_label

        # if num_samples is not provided, 
        # draw `len(indices)` samples in each iteration
        self.num_samples = len(self.indices) \
            if num_samples is None else num_samples
            
        # distribution of classes in the dataset 
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(dataset, idx)
            if label in label_to_count:
                label_to_count[label] += 1
            else:
                label_to_count[label] = 1
        # weight for each sample
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
                   for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        if isinstance(dataset, torchvision.datasets.MNIST):
            return dataset.train_labels[idx].item()
        elif isinstance(dataset, torchvision.datasets.ImageFolder):
            return dataset.imgs[idx][1]
        elif self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        else:
            raise NotImplementedError
           

    def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True))

    def __len__(self):
        return self.num_samples

在_get_label方法中可以看到,如果你想用自己的数据集,需要补全参数callback_get_label,当然,如果你的数据集是按照pytorch官方教程来写的,格式与torchvision.datasets.ImageFolder一样,那就很容易了,只需要在else下方将源码修改为:return dataset.imgs[idx][1],这样就大功告成了~

怎么安装

在之前给出的git网址中已经给出具体方法,为方便在此处说明一下:
下载解压后,调出终端,移动到刚才解压的文件夹下运行以下命令:

pip install .

python setup.py install
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值