论文阅读4--CReST: A Class-Rebalancing Self-Training Frameworkfor Imbalanced Semi-Supervised Learning

目录

写在前面

0.Abstract

1.Introduction

2.Related work

2.1. Semi-supervised learning

2.2. Class-imbalanced supervised learning

2.3. Class-imbalanced semi-supervised learning

3. Class-Imbalanced SSL(类不平衡的半监督学习)

3.1. Problem setup and baselines

3.2. A closer look at the model bias

3.3. Class-rebalancing self-training

3.4. Progressive distribution alignment

4. Experiments

4.1CIFAR-LT

4.2. ImageNet127

4.3. Ablation study

5. Conclusion

6.启发

7.讨论


写在前面

   知识补充

不平衡学习

最常用的方法是根据类别样本量重新平衡培训目标。其中两种方法具有代表性: a) 重新加权,其通过将相对较高的成本分配给次要类别的示例来影响损失函数   b) 重新采样,其通过对少数类的过采样或对多数类的过采样,或两者都直接调整标签分布,以获得平衡的采样分布。然而,天真地重新平衡目标通常会导致过度适应少数群体。最近,通过将特征从多数类转移到代表性不足的少数类,也提出了基于转移学习的方法。但是,这些方法假定所有标签都可用,并且不能直接应用于SSL方案。

长尾分布(Long-Tailed Distribution)

自然界中收集的样本通常呈长尾分布,即收集得到的绝大多数样本都属于常见的头部类别(例如猫狗之类的),而绝大部分尾部类别却只能收集到很少量的样本(例如熊猫、老虎),这造成收集得到的数据集存在着严重的类别不平衡问题(Class-Imbalanced),从而使得训练得到的模型严重的过拟合于头部类别。

对于解决长尾分布的方法有很多,例如重采样 (Re-Sampling) 以及重加权 (Re-Weighting)。重采样简单来说可以划分为两类,

  • 一是通过对头部类别进行「欠采样」减少头部类别的样本数,

  • 二是通过「过采样」对尾部类别进行重复采样增加其样本数,从而使得类别“平衡”。

但这样naive的方法存在的缺点也显而易见,即模型对尾部类别过拟合以及对头部类别欠拟合。

重加权方法的核心思想是类别少的样本应该赋予更大的权重,类别多的样本赋予更少的权重。此外有一篇文章[1]提出样本之间存在大量的信息冗余,因此提出了一个类别「有效样本数」的概念,还挺有意思,这里就不展开了。

半监督学习

由于数据标注过程的高成本,在许多任务中很难获得强有力的监督信息,半监督学习的目标是利用未标记的数据得到更好表现的模型。而未标记的数据集可以为我们提供关于数据真实分布的一些额外信息

0.Abstract

      基于类不平衡数据的半监督学习虽然是一个现实问题,但仍在研究中。虽然已知现有的半监督学习(semi-supervised learning, SSL)方法在少数类上表现不佳,但我们发现,它们仍然在少数类上生成高精度的伪标签。通过利用这一特性,我们提出了类平衡自训练(CReST),一个简单而有效的框架来改进现有的类不平衡数据SSL方法。CReST通过从一个未标记的集合中添加伪标记的样本来迭代地重新训练一个基线SSL模型,在该集合中,根据一个估计的类别分布,更频繁地选择来自少数群体的伪标记样本。我们还提出了一种渐进分布对齐来自适应调整再平衡强度,称为CReST+。我们展示了CReST和CReST+在各种类不平衡数据集上改进了最先进的SSL算法,并始终优于其他流行的再平衡方法。

1.Introduction

     半监督学习(Semi-supervised learning, SSL)利用无标记数据来提高模型性能,并在标准 SSL 图像分类基准上取得了很有希望的结果。在构建 SSL基准数据集期间,通常会隐式地做出一个常见的假设,即标记数据和/或未标记数据的类分布是平衡的。然而,在许多现实场景中,这种假设是不正确的,并且成为了导致SSL性能低下的主要原因。

     关于不平衡数据的监督学习已经得到了广泛的研究。通常可以观察到,在不平衡数据上训练的模型偏向于具有大量例子的多数类,而远离具有很少例子的少数类。人们提出了各种解决方案来帮助缓解偏差,如重采样、重加权和两阶段训练。所有这些方法都依赖于标签来重新平衡有偏差的模型。

     相比之下,对非平衡数据的 SSL研究还不够深入。事实上,数据不平衡在 SSL 中带来了进一步的挑战,因为缺少标签信息妨碍了对未标签集的重新平衡。SSL算法中通常使用经过标记数据训练的模型生成的未标记数据的伪标签。然而,如果伪标签是由在不平衡数据上训练的初始模型生成的,并且偏向于大多数类,那么伪标签可能会有问题:随后使用这种有偏差的伪标签进行训练,会加剧这种偏差并恶化模型质量。除了最近的一些研究,现有的大多数SSL算法都没有对不平衡的数据分布进行充分的评估。

                       

     如图 1(a)所示,标记集和未标记集的不平衡类分布大致相同的不平衡数据。我们观察到,现有 SSL 算法在不平衡数据上的不理想性能主要是由于少数类的低召回率。我们的方法是受到进一步观察的推动,尽管如此,对少数族裔的精确程度还是惊人的高。在图 1(b) 中,我们展示了对 FixMatch生成的CIFAR10-LT 数据集的预测,FixMatch是一个具有代表性的SSL算法,在平衡基准上具有最先进的性能。该模型在多数类别上的召回率较高,但在少数类别上的召回率较低,导致在平衡测试集上的总体准确率较低。然而,该模型在少数群体类别上的精度几乎是完美的,这表明该模型在将样本划分为少数群体类别方面是保守的,但一旦它做出这样的预测,我们就可以相信它是正确的。对其他 SSL 方法和监督学习也进行了类似的观察。

      考虑到这一点,我们引入了一个类重新平衡自训练方案(CReST),该方案在自适应地从无标签集合中采样伪标签数据以补充原始标签集合后,重新训练一个基线 SSL 模型。我们将每个经过充分训练的基线模型视为一个迭代。在每一代之后,从无标签的集合中添加伪标签的样本到有标签的集合中,以重新训练SSL模型。我们不是用所有的伪标签样本更新标记集,而是使用随机更新策略,在这种策略中,如果样本被预测为少数类,则被选择的概率更高,因为这些样本更有可能是正确的预测。更新概率是由标记集估计的数据分布的函数。此外,我们将 CReST 扩展到CReST+,将分布对齐与温度比例因子结合,以控制其几代后的对齐强度,从而更积极地调整预测数据分布,以缓解模型偏差。如图 1(c)和 1(d)所示,所提出的策略降低了伪标记的偏倚,从而提高了类平衡测试集的准确性。

     我们在实验中表明,CReST和CReST+比基线SSL方法有很大的改进。在 CIFAR-LT 上,我们的方法在不同的不平衡比例和标签分数下比 FixMatch准确率高 11.8%。我们的方法在 MixMatch和 FixMatch上的准确率也超过了 DARP, DARP是一种最先进的 SSL算法,用于从不平衡数据中学习。为了进一步测试该方法在大规模数据上的有效性,我们将该方法应用于 ImageNet通过基于语义层次的类合并创建的自然不平衡数据集 ImageNet127上,召回率提高了7.9%。广泛的消融研究进一步表明,我们的方法特别有助于提高少数族裔的记忆,使其成为一种可行的解决 SSL不平衡的方法。

2.Related work

2.1. Semi-supervised learning

     近年来,SSL 研究取得了重大进展。这些方法中有许多具有与深度学习相似的基本技术,如熵最小化、伪标记或一致性正则化。伪标记利用模型自身预测得到的伪标记目标,用未标记数据训练分类器。相关的,使用带有温度标度的模型预测概率作为软伪标签?。一致性正则化通过提高未标记数据的不同视图之间预测的一致性来学习分类器,无论是通过软还是硬伪标签。生成多视图的有效方法有:变强度的输入数据增强、网络层内的标准 dropout、随机深度。最近的SSL方法的性能依赖于伪标签的质量。然而,上述著作中都没有研究阶级不平衡环境下的 SSL,在这种环境下伪标签的质量受到模型偏差的严重威胁。

2.2. Class-imbalanced supervised learning

    对类别失衡的监督学习的研究越来越受到关注。杰出的作品包括 re-sampling和 reweighting,它们重新平衡了每个类的贡献,而另一些则专注于重新权衡每个实例。一些工作旨在将知识从多数阶级转移到少数阶级(迁移学习)。最近的一个研究趋势提出将表示学习和分类器学习去耦。这些方法假设在培训期间所有标签都是可用的,并且它们的性能在SSL场景下基本上是未知的。

2.3. Class-imbalanced semi-supervised learning

   虽然SSL得到了广泛的研究,但它对于类不平衡数据的探索还不够。最近,Yang 和Xu认为,利用SSL和自我监督学习的无标签数据可以有利于类别不平衡学习。Hyun 等人提出了抑制一致性损失来抑制少数群体的损失。Kim等人提出了分布式校准精炼厂(DARP),通过凸优化来细化原始伪标签。相反,我们通过类再平衡抽样策略和渐进分布对齐策略直接提高模型原始伪标签的质量。DARP 还讨论了另一种有趣的设置,即标记数据和未标记数据不共享相同的类别分布,而在本工作中,我们关注的是标记数据和未标记数据具有大致相同的分布的情况。 

3. Class-Imbalanced SSL(类不平衡的半监督学习)

    在本节中,我们首先设置问题并介绍基线 SSL算法。接下来,我们研究了现有 SSL算法对类不平衡数据的偏差行为。基于这些观察结果,我们提出了一个类 -再平衡自我训练框架(CReST),该框架利用了模型的偏差,而不是受到模型的影响,以缓解少数类的性能退化。此外,我们对分布对齐进行了扩展,并将其集成为CReST+,进一步提高了在线伪标注的质量。

3.1. Problem setup and baselines

     我们首先提出了类别不平衡半监督学习的问题。对于L-class分类任务,有一个标记集X=(x_{n}, y_{n}): n∈(1,…,N);其中x_{n}R^{d}为训练示例,y_{n}∈{1,…, L}为相应的类标号。类l的X中训练例数记为N_{l},  i.e.\sum_{l=1}^{L} N_{l}=N,不失一般性,我们假设类按基数降序排序,即N1≥N2≥···≥NL。X的边缘阶级分布是倾斜的,即N_{1}>>N_{L}.用不平衡比来衡量等级不平衡程度,\gamma =\frac{N_{1}}{N_{L}}.

     除了有标记的集合X外,一个无标记的集合U=u_{m}R^{d}:m∈(1,…,M)也提供了与X相同的类分布。(标记数据与未标记数据类分布相同)标签分数β =\frac{N}{N+M}度量了标签数据的百分比。给定类不平衡集合X和U,我们的目标是学习一个分类器f: R^{d}→{1,…, L},在类平衡检验准则下具有很好的泛化性。

     许多最先进的SSL方法通过为分类器的预测分配伪标签ˆy_{m}=f\left ( u_{m} \right )来利用未标记的数据。然后,分类器在有标签和无标签的样本及其相应的伪标签上进行优化。因此,伪标签的质量对最终性能至关重要。这些算法成功地在标准的类平衡数据集上工作,因为分类器的质量-因此它的在线伪标签-在训练过程中提高了所有类别的性能。问题:然而,当分类器一开始由于类分布的偏置而产生偏置时,未标记数据的在线伪标签可能更加偏置,进一步加剧了类的不平衡问题,导致少数类性能严重下降。

3.2. A closer look at the model bias

   之前的工作引入了具有不同类不平衡比率的CIFAR数据集的长尾版本,以评估类不平衡全监督学习算法。我们通过保留部分已标记的训练样本和其余未标记的训练样本来扩展该协议。我们测试FixMatch,它是为类平衡数据设计的最先进的SSL算法之一。图2为不平衡比γ=100,标签分数β =10%的CIFAR10LT和不平衡比γ=50,标签分数β =30%的CIFAR100-LT各类的测试查全率和查准率。

      FixMatch模型在类不平衡数据上的偏差。左图:CIFAR10-LT的每级召回和精度。图:CIFAR100-LT 的每级召回和精度。类索引是根据示例的数量降序排序的。虽然传统的假设可能是大多数人的表现比少数人要好,但我们发现这只是部分正确。该模型在多数类别上的召回率较高,但准确率较低;在少数类别上的召回率较低,但准确率较高。更多细节见第 3.2节。

     首先,如图2的第一个和第三个情节所示,FixMatch在多数类上的召回率很高,而在少数类上的查全率很低,这与传统观点是一致的。例如,CIFAR10-LT最多数类和第二多数类的召回率分别为98.5%和99.7%,而从最少数类中,模型只能正确识别8.4%的样本。换句话说,该模型高度偏向大多数类别,导致所有类别的平均召回率较低,这也被称为准确性,因为测试集是平衡的。

    尽管召回率较低,但少数群体的准确率却出奇地高,如图2的第二和第四幅图所示。例如CIFAR10-LT中少数类最多的模型精度达到97.7%,第二少数类的模型精度达到98.3%,而大多数类的模型精度相对较低。这表明许多少数群体样本被预测为多数群体之一。

    虽然传统方法可能表明,大多数类的表现比少数类的表现更好,但我们发现这只是部分正确:从类别不平衡数据中学习到的有偏模型在召回率方面确实倾向于多数类别,但在精确度方面则更倾向于少数类别。对其他SSL算法以及全监督类不平衡学习也进行了类似的观察。这一实证发现促使我们利用少数群体的高精确度来减轻他们的召回率退化。为了实现这个目标,我们引入了CReST,这是一个类重新平衡的自我训练框架,如图3所示。

3.3. Class-rebalancing self-training

    self-training是SSL中广泛使用的一种迭代方法。它对模型进行多代训练,每次迭代涉及两个步骤。首先,在标注集上对模型进行训练,得到教师模型。其次,教师模型的预测被用来为未标记的数据生成伪标签ˆy_{m}。将伪标记集Uˆ=(u_{m},ˆy_{m})包含在标记集内,即X=X∪U,作为下一代的ˆU。

     为了适应班级的不平衡,我们对自我训练策略提出了两种修改。首先,我们使用SSL算法来利用已标记和未标记的数据,从而在第一步中获得一个更好的教师模型,而不是仅仅对已标记的数据进行培训。更重要的是,在第二步中,我们并不是将ˆU中的每个样本都包含在标记集中,而是将标记集展开成一个选定的子集ˆS⊂ˆU,即X=X∪ˆS。我们选择ˆS,遵循一个类平衡规则:类l出现的频率越低,预测为类l的未标记样本就会被包含到伪标记集ˆS中。

    我们从标记集估计类分布。其中,预测为l类的未标记样本以速率被纳入ˆS

                           

    其中,α≥0调节采样率,从而调节S的大小ˆ。例如,对于一个10类不平衡数据集,不平衡比为γ= N1/ N10 =100,我们将所有样本预测为最少数类,因为\mu _{10} =(N_{10}+1−10)/N_{1}^α =1。而对于大多数类别,样本的µ1 =(N10+1−1)/N1^α =0.01^α。当α=0,\mu _{l} =1时,保留所有未标记的样本,算法恢复到常规的自训练。在每个类中选取伪标签样本时,我们选取最自信的样本。

    我们的CReST战略的动机是双重的。首先,如第3.2节所观察到的,少数类的精度要比多数类的精度高得多,因此少数类伪标签包含在标记集中的风险更小。其次,由于数据稀缺,将少数群体添加到样本中更为关键。有更多的少数类别样本添加后,标签集的类均衡程度更高,从而使得后续生成的在线伪标注分类器的偏倚更小。请注意,还有其他以阶级平衡的方式对伪标签进行采样的方法,我们提供了一个实用而有效的例子。

方法(Method)

作者 follow 半监督学习中 self-training 的过程:

  • 使用标准的 SSL 算法利用已标记集和未标记集的信息训练一个有效的模型

  • 给未标记集  中的每个样本打上伪标记得到新的数据集 

  • 「挑选出模型的预测类别属于尾部类别的样本作为候选集  加入到已标记集合中」

最妙的一步在第三步,「模型预测的类别属于尾部类别意味着这些样本的伪标记具有很高的置信度的(High precision),因为此时的模型是对头部类别过拟合的,此时模型还将某一样本预测为尾部类别说明该伪标记真的是该样本的 ground-truth。从另一方面,这一采样又巧妙的引入了尾部类别样本,从而缓解了类别不平衡问题。」

3.4. Progressive distribution alignment

     我们进一步提高了在线伪标签的质量,在CReST中增加了渐进分布对齐,并将其区分为CReST+。

    虽然最初是为类平衡的SSL引入的,但分布对齐(DA)[1]特别适合于类不平衡的场景。它将模型在未标记样本上的预测分布与标记训练集的类分布p(y)对齐。Let ~ p(y)是模型对未标记示例的预测的移动平均。DA首先缩放模型的预测q =p(y|um;(F)对于未标记的示例um,用p(y) p属于(y),将q与目标分布p(y)对齐。

     为了进一步增强DA处理类平衡数据的能力,我们对其进行了温度标度扩展。具体来说,我们增加了一个调谐旋钮t∈[0,1],它控制DA的类再平衡强度。我们不直接以p(y)为目标,而是使用温度尺度分布的归一化(p(y)t)。当t=1时,我们得到DA。当t<1时,温度尺度分布变得更平滑,更积极地平衡模型的预测分布。当t=0时,目标分布一样。

    虽然使用较小的t可以在类平衡测试标准下使单个代受益,但它对多代自我训练不太理想,因为它影响伪标签的质量。具体来说,应用t<1使得模型的预测分布比训练集的类分布更均衡,从而使模型更频繁地预测少数类。然而,在少数类样本较少的非平衡训练集上,这种伪标记倾向于过平衡,即错误地将更多的样本预测为少数类。这降低了少数类的高精度,干扰了我们利用它来产生更好的伪标签的能力。

    为了解决这个问题,我们建议通过在每代中减少t来逐步增加职业再平衡的力量。具体来说,我们用当代g的一个线性函数来设置t,该函数从0开始:

                              

    其中G+1是总的迭代数,tmin是最后一代使用的温度。这种t的进步性时间表在早期具有较高的伪标签精度,在后期具有较强的类平衡。它还加快了迭代训练的速度,以更少的训练次数获得更好的结果。实证分析见4.3节。

4. Experiments

4.1CIFAR-LT

Datasets:我们首先对文献中介绍的长尾CIFAR10 (CIFAR10- lt)和长尾CIFAR100 (CIFAR100- lt)的有效性进行了评估。在这些数据集上,每个类随机丢弃训练图像,以保持预定义的不平衡比例γ。其中,Nl =\gamma ^{-\frac{l-1}{L-1}}·N1, CIFAR10-LT  N1为5000,L=10; CIFAR100-LT为N1 =500, L=100。我们从训练数据中随机选取β =10%和30%的样本来创建标记集,并测试CIFAR10-LT的不平衡比γ=50、100和200,CIFAR100-LT的不平衡比γ=50和100。测试集保持原状和平衡,为了使评估标准,测试集上的准确性,是类平衡的。

Setup:我们使用Wide ResNet-28-2跟随作为骨干。我们在FixMatch和MixMatch上评估我们的方法。对于每一代,当使用FixMatch作为基线SSL算法时,模型将接受216个步骤的训练,而MixMatch将接受217个步骤的训练。我们使用余弦学习速率衰减,其公式在补充材料中提供。每个训练生成的其他超参数是不受影响的。对于CReST和CReST+相关的超参数,我们为FixMatch设置α=1 / 3, tmin =0.5,为MixMatch设置α=1 / 2, tmin =0.8。CReST需要15代,而CReST+只需要6代,通过渐进分布对齐加速。超参数的选择基于CIFAR10LT的单次折叠γ=100和β =10%。我们每210步评估测试数据集上的模型,并报告最后5个评估的平均测试精度。每个算法在标记数据的5个不同折叠下进行测试,我们报告了测试集上准确性的平均值和标准差。在[2]和[39]之后,我们使用模型参数的指数移动平均来报告最终性能。

Main results:首先,我们将我们的模型与基线FixMatch进行比较,并在表1中显示结果。虽然FixMatch在不平衡比γ=50上表现良好,但其精度随着不平衡比的增加而显著下降。相比之下,CReST提高了FixMatch在所有评估设置上的准确性,并实现了高达9.6%的绝对性能增益。当合并渐进分布对齐时,我们的CReST+模型能够进一步提高所有设置上的性能几个点,与基线FixMatch相比,其绝对精度提高了3.0%至11.8%。所有比较方法的准确性都随着标记样本数量的增加而提高,但CReST始终优于基线。这说明CReST在类分布不平衡的情况下,可以更好地利用标记数据来减少模型偏差。我们还观察到,我们的模型工作得特别好,当不平衡比γ=100时,在标记数据为10%和30%的情况下,分别达到11.8%和6.1%的准确性增益。我们假设原因是我们的模型发现更多正确的伪标记样本来扩大标记集。然而,当不平衡比很高时,如γ=200,我们的模型的能力受到少数类训练样本数量不足的限制。

Comparison with baselines:我们在表2中进一步报告了其他SSL基线的性能。为了便于比较,所有算法都经过了6×2^16步骤的训练。这导致FixMatch基础上的CReST和CReST+有6代,每代216步,MixMatch基础上的CReST和CReST+有3代,每代217步。其他不使用自我训练的模型使用6×2^16步骤进行单一代的训练。

                 

我们将CReST和CReST+与基线方法进行比较,包括不同的SSL算法和为全监督学习设计的典型类平衡技术。为了进行公平的比较,所有模型都按照相同的训练步骤进行测量。详情请参阅文本。3个不平衡比率γ与β =10%标签评估。数字在5个不同的折叠中平均。

                      

CIFAR10上DARP协议[22]下的准确度(%)。有关数据集的详细信息,请参阅补充材料。计算了三种不平衡比γ。在5次运行中取平均值。

                    

在ImageNet127上用β =10%的样本进行评估。我们用我们的CReST和CReST+重新训练FixMatch模型3代。

      我们首先直接评估了在类不平衡数据集上的几种经典SSL方法,包括Pseudo-Labeling[26]、Mean Teacher[43]、MixMatch[2]和FixMatch[39]。由于数据的不均衡,SSL基线的精度普遍较低,且不均衡比例越高,准确率下降越明显。在MixMatch上,CReST提供的改进是适度的,这主要是由于调度约束。提供更多的生成预算,MixMatch与CReST的结果可以进一步改进。在这些算法中,FixMatch的性能最好,因此我们将其作为各种再平衡方法的基线。

Comparison with DARP:我们直接与DARP[22]进行比较,它是专为不平衡数据设计的最新最先进的SSL算法。DARP和我们的方法都建立在MixMatch和FixMatch的基础上,作为标准SSL算法的添加。我们将我们的方法应用于DARP中使用的完全相同的数据集,并在表3中给出了结果。有关数据集构造的详细资料在补充资料中提供。对于这三种不平衡比率,我们的模型在MixMatch上比DARP始终达到4.0%的精度增益,在FixMatch上达到2.4%的精度增益。

4.2. ImageNet127

Datasets:我们还在ImageNet127上评估CReST,以验证其在大规模数据集上的性能。ImageNet127最初是在中引入的,其中ImageNet[11]的1000个类在WordNet中根据它们自顶向下的层次结构被分组为127个类。它是一个不平衡比γ=286的自然不平衡数据集。它最主要的类“哺乳动物”由218个原始类和277,601个训练图像组成。而其最少数的班级“蝴蝶”则是由一个单一的原创班级组成,有969个训练例子。我们随机选取β =10%的训练样本作为标记集,并保持测试集不变。由于类分组的原因,测试集不均衡。因此,我们计算平均类召回率而不是准确性,以实现一个平衡的指标。 

  我们注意到,还有其他大型数据集,如iNaturalist[10]和ImageNet-LT[29],它们经常作为全监督长尾识别算法的测试平台。然而,这些数据集包含的少数类的例子太少,无法形成统计上有意义的数据集,也无法为半监督学习得出可靠的结论。例如,在ImageNet-LT数据集的最少数类中只有5个示例。

Setup:我们使用ResNet50[15]作为骨干。每个训练生成的超参数都来自原始的FixMatch文件。在α=0.7, tmin =0.5的条件下,进行了3代自训练。

Results:我们在表4中报告结果。给出了带有100%和10%标记训练示例的监督学习和带有温度标度的DA。与基线FixMatch相比,经过3代自训练后,CReST和CReST+均有逐步提高,而CReST+最终提供了7.9%的绝对增益,验证了我们所提方法的有效性。

4.3. Ablation study

     我们进行了一项广泛的消融研究,以评估和了解每个关键成分在CReST和CReST+中的作用。本节的实验均使用FixMatch在CIFAR10-LT上进行,不平衡比γ=100,标签分数β =10%,标签数据为一倍。

     抽样率的影响:CReST引入了采样率超参数α,该参数控制每个类的采样率和被选中的伪标记样本被包含在标记集中的数量。在图4中,我们展示了α如何影响几代人的性能。当α=0时,我们的方法回归到传统的自训练,将所有未标记的例子及其对应的预测标记展开标记集。然而,传统的自我训练不会在几代人之后产生性能增益,这表明简单地应用自我训练并不能提供性能改进。而我们的类再平衡抽样策略(α>0)则通过迭代模型再训练来提高精度。

     如图4(a)所示,α值越小,表示被标记集合中加入的伪标记样本越多,这就扩大了被标记集合,但反过来引入了更多低质量的伪标记。另一方面,较大的α值使得伪标签样本倾向于少数群体。因此,当α值较大时,类别再平衡抽样可能过于强大,导致向相反方向的不平衡,朝向原始的少数群体。这是α=1的情况,其中,在第10代之后,模型变得越来越偏向少数群体,并遭受性能下降的多数群体,导致准确性下降。例如,从第10代到上一代,大多数少数类的召回率从55.0%大幅上升到71.1%,而其他9个类中有7个类的召回率严重下降,导致类平衡测试集准确率下降了3.0%。实证研究发现,α=1 / 3在CIFAR长尾数据集上实现了伪标签质量与不同不平衡比例和标签分量的分类再平衡强度之间的平衡。

             

α跨代对CReST CIFAR10LT (γ =100, β =10%)的影响。(a)说明α如何影响采样率。(b)不同α世代的检测精度。当α=0时,该方法返回到传统的自我训练,将所有未标记的例子和相应的伪标签添加到标记集中,经过几代再训练后没有任何改善,而我们的类再平衡抽样(α>0)有所帮助。

             

多代温度t对CIFAR10-LT (γ =100, β =10%)的影响。(a)说明t如何控制分布对齐的目标分布。(b)使用不同的常数t和我们的CReST+进行几代的测试精度。与使用常数t相比,在6代的时间内,从t=0到tmin =0.5, CReST+获得了最好的最终精度。

Effect of progressive temperature scaling.在CReST+中使用的自适应分布对齐引入了另一个超参数,温度t,用于缩放目标分布。我们首先在图5(a)中说明了温度t如何在分布对齐中平滑目标分布,使较小的t提供更强的再平衡强度。在图5(b)中,我们研究了使用恒定温度和我们提出的渐进温度标度的影响,在每代自我训练中,t从1.0逐渐减小到tmin =0.5。

    首先,我们发现t=0.5在所有测试温度值中提供了75.1%的单代精度。这表明,与初始分布对齐(适度t固定为1.0)的70.0%精度相比,通过适当的“平滑”目标分布,模型可以受益于类再平衡。进一步降低t到0.1会导致更低的精度,因为目标分布过于平滑,引入更多的伪标记误差。

    经过几代的自我训练,使用一个常数t不是最佳的。虽然一个相对较小的t(例如0.5)可以在早期提供更好的性能,但由于伪标签质量下降,它不能通过继续自我训练提供进一步的收益。当t低于0.5时,性能甚至会在后续几代之后下降。相比之下,本文提出的CReST+逐步增强分布对齐强度,提供了上一代最佳的精度。

Per-class performance.为了显示准确率提高的来源,在表5中,我们展示了CIFAR10-LT平衡测试集的每类召回率,不平衡率为100,标签分数为10%。CReST和CReST+在4个多数职业上牺牲了一些精度点,但在其他6个少数职业上提供了显著的增益,获得了比所有职业更好的性能。我们也包括关于非平衡无标记集的结果。结果与测试集的结果特别相似,多数阶级的成绩轻微下降,少数阶级的成绩显著提高。这表明我们的方法确实提高了伪标签的质量,它可以在一个平衡的测试标准上转化为更好的泛化。

5. Conclusion

       在这项工作中,我们提出了一个类重新平衡的自我训练框架,命名为CReST,用于不平衡半监督学习。CReST的动机是观察到现有的SSL算法在少数类上产生高精度的伪标签。CReST通过用高质量的伪标签补充标签集,迭代地细化基线SSL模型,其中少数类比多数类更新更积极。经过几代的自我训练,这种模式变得不那么偏向多数阶级,而是更多地关注少数阶级。我们还扩展了分布对齐,以逐步增加它的类再平衡强度,并表示组合的方法CReST+。在CIFAR长尾数据集和ImageNet127数据集上的大量实验表明,提出的CReST和CReST+大大改进了基线SSL算法,并始终优于最先进的再平衡方法。

6.启发

动机(Motivation)

本文的问题设置更为复杂,考虑的是半监督场景下的长尾分布问题,「即此时我们不仅没有足够的有标记样本,而且这些有标记样本的分布还是长尾分布的(类别不平衡的)」。面对这么困难的问题,作者倒是不慌不忙,首先做了一个很有意思的实验。

作者使用 「FixMatch」 模型 (一个解决半监督问题的SOTA方法) 分别在具有长尾分布的「CIFAR10-LT」 (左边两张图) 以及 「CIFAR100-LT」 (右边两张图) 上进行了实验。其中横坐标代表长尾分布的不同类别,越小的数字代表是头部类别,越大的数字代表是尾部类别;纵坐标对应红点和蓝点分别是 Recall 和 Precision。

实验现象表明,「模型对头部类别的样本 Recall 很高,对尾部类别的 Recall 很低;模型对头部类别样本的 Precision 很低,但对尾部类别的 Precision 却很高」。这是一个很常见的类别不平衡问题里的过拟合现象,换句话来说,「模型对不确定性很高的尾部类别样本都预测成头部类别了。」

举个例子,我在训练阶段喂入模型100张猫的图片以及10张狗的图片,在测试阶段时会发现对于模型把握不准的狗的图片都会预测成猫,只有模型特别有把握的狗的图片才会预测成狗,此时会造成猫这个类别的 Recall 会非常高 Precision 却会非常低,反之狗这个类别的 Recall 会非常低但 Precision 却会非常高。

这个实验现象是符合直观的,但是怎么来运用上述这一信息呢?

7.讨论

   利用了模型学习长尾分布样本表现出来的规律,「既利用了未标记样本的真实标记,又利用了尾部类别的样本。」

后面深度思考了一下这件事:

  • 这些被挑选出来的样本虽然有很大的可能具有正确的伪标记,但它可能不太具备代表性,即不能很好的代表这个类。换句话来说,模型对这些样本具有很大的置信度,即这些极可能是简单样本,对模型的学习帮助性可能不大,因此此时模型已经很确信能将其预测对了,此时再引入这些样本的loss其实很小,对模型的影响也不大。

  • 针对前面所提到的,所以我认为可能性能的提升绝大部分来自于类别平衡了,当然正确的简单样本的引入也会对模型性能提升有帮助。

  • 这个方法由此也会在半监督场景下作用明显,因此本来就没啥有标记样本,还如此的类别不平衡,此时给一些正确标记的虽然简单的样本对模型训练也是很有帮助的。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值