动机
不平衡学习是一种机器学习范例,其中分类器必须从具有倾斜的类分布的数据集中学习。不平衡的数据集可能对分类器的性能产生不利影响。
重新平衡数据集是处理类不平衡的一种方法。这可以通过以下方式完成:
- 采样不足的普通类。
- 对稀有类进行过度采样。
- 两者兼而有之。
PyTorch提供了一些用于重新平衡数据集的实用程序,但它们仅限于已知长度的批处理数据集(即,它们要求数据集具有__len__方法)。诸如ufoym / imbalanced-dataset-sampler之类的社区贡献很可爱,但它们也仅适用于批处理数据集(在PyTorch行话中也称为地图样式数据集)。 pytorch / pytorch存储库上还存在一个GitHub问题,但它似乎不太活跃。
因此,该存储库实现了包装IterableDataset的数据重采样器。在此拉取请求中,后者已添加到PyTorch。特别是,提供的方法不需要您必须事先知道数据集的大小。每种方法都适用于二进制和多类分类。
安装
$ pip install pytorch_resample
用法
作为一个正在运行的示例,我们将定义一个IterableDataset,它对scikit-learn的make_classification函数的输出进行迭代。
>>> from sklearn import datasets>>> import torch>>> class MakeClas