在许多机器学习应用程序中,我们经常遇到一些数据集,其中某些类型的数据可能比其他类型的数据更容易被看到。以罕见病鉴定为例,正常标本可能多于病标本。在这些情况下,我们需要确保经过训练的模型不会偏向于拥有更多数据的类。例如,考虑一个数据集,其中有5张疾病图像和20张正常图像。如果模型预测所有图像均为正常,则准确率为80%,该模型的f1分为0.88分。因此,该模型有向“正常”类倾斜的高趋势。为了解决这个问题,一种被广泛采用的技术被称为重采样。它包括从多数类中删除样本(抽样不足)和/或从少数类中添加更多示例(抽样过多)。尽管平衡类课程有其优势,但这些技术也有其弱点(没有免费的午餐)。过最简单的实现是复制来自少数类的随机记录,这可能导致过度拟合。最简单的方法是从多数类中删除随机记录,这会导致信息丢失。
我们实现了一个易于使用的PyTorch采样器ImbalancedDatasetSampler,它能够从不平衡的数据集采样时重新平衡类分布自动估计采样权值避免创建新的平衡数据集当它与数据扩充技术一起使用时,减轻过度拟合
from sampler import ImbalancedDatasetSampler
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=ImbalancedDatasetSampler(train_dataset),
batch_size=args.batch_size,
**kwargs
)