总说
针对类别数目不均匀的数据,有些类图片多,有些少,如果直接训练,那么就会造成过拟合类别多的数据。最简单的方法就是重采样,直接根据每一类的数目,来重新分配权重。你想想,普通肯定是均匀概率采样的,自然数目多的图片,采样到的概率就大。
神奇的 WeightedRandomSampler
直接丢代码
# 数据集中,每一类的数目。
class_sample_counts = [150, 200, 300]
weights = 1./ torch.tensor(class_sample_counts, dtype=torch.float)
# 这个 get_classes_for_all_imgs是关键
train_targets = train_dataset.get_classes_for_all_imgs()
samples_weights = weights[train_ta