CReST: A Class-Rebalancing Self-Training Framework for Imbalanced Semi-Supervised Learning. (CVPR, 2021)
解决问题:半监督场景下的长尾问题
首先实验中发现
模型对头部类别的样本 Recall 很高,对尾部类别的 Recall 很低;模型对头部类别样本的 Precision 很低,但对尾部类别的 Precision 却很高。这是一个很常见的类别不平衡问题里的过拟合现象,换句话来说,模型对不确定性很高的尾部类别样本都预测成头部类别了。举个例子,我在训练阶段喂入模型100张猫的图片以及10张狗的图片,在测试阶段时会发现对于模型把握不准的狗的图片都会预测成猫,只有模型特别有把握的狗的图片才会预测成狗,此时会造成猫这个类别的 Recall 会非常高 Precision 却会非常低,反之狗这个类别的 Recall 会非常低但 Precision 却会非常高。
图示
1.使用标准的 SSL 算法利用已标记集和未标记集的信息训练一个有效的模型。
2.给未标记集中的每个样本打上伪标记得到新的数据集。
3.挑选出模型的预测类别属于尾部类别的样本作为候选集加入到已标记集合中。
可以这样做的原因
模型预测的类别属于尾部类别意味着这些样本的伪标记具有很高的置信度的(High precision),因为此时的模型是对头部类别过拟合的,此时模型还将某一样本预测为尾部类别说明该伪标记真的是该样本的 ground-truth。从另一方面,这一采样又巧妙的引入了尾部类别样本,从而缓解了类别不平衡问题。
仍然存在的问题
这些被挑选出来的样本虽然有很大的可能具有正确的伪标记,但它可能不太具备代表性,即不能很好的代表这个类。换句话来说,模型对这些样本具有很大的置信度,即这些极可能是简单样本,对模型的学习帮助性可能不大,因此此时模型已经很确信能将其预测对了,此时再引入这些样本的loss其实很小,对模型的影响也不大。
针对前面所提到的,所以我认为可能性能的提升绝大部分来自于类别平衡了,当然正确的简单样本的引入也会对模型性能提升有帮助。
这个方法由此也会在半监督场景下作用明显,因此本来就没啥有标记样本,还如此的类别不平衡,此时给一些正确标记的虽然简单的样本对模型训练也是很有帮助的。