FixMatch(ReMixMatch的大大简化版本)
对于无标签样本前人的做法
基于一致性正则化
最小化交叉熵损失或均方误差损失,使得模型对于同一样本的两种不同数据增强的预测结果尽可能地接近。
基于伪标签
用有标签样本训练好的模型为无标签样本打伪标签qb,设定一个置信度阈值τ,选取伪标签置信度大于阈值τ的样本,用其伪软标签qb和argmax(qb)得到的one-hot标签qb^(伪硬标签)计算交叉熵损失,使得模型对所做出的预测尽可能地自信(熵最小化)
作者的做法(对于无标签样本的处理结合了上述两种思想)
弱数据增强(flip裁剪、shift平移);强数据增强(RandAugment、CTAugment、Cutout),Cutout默认地用于RandAugment和CTAugment之后(各数据增强方法的参数设置见原论文2.3节)
FixMatch的损失包含两部分,ls是对有标签样本而言的、lu是对无标签样本而言的。
ls是对有标签样本的损失,是标准的交叉熵损失,使得模型对弱数据增强的预测结果和原样本的标签pb尽可能地接近。
α(xb)代表对样本xb做弱数据增强,B为batchsize,pm为模型。
lu是对无标签样本的损失
qb是弱数据增强样本的伪软标签,argmax(qb)得到的one-hot标签qb^ 是弱数据增强样本的伪硬标签。A(xb)代表对样本xb做强数据增强,μB为batchsize,pm为模型。
和前人的做法相同的是:都是只取伪标签的置信度大于阈值τ的无标签样本去算交叉熵。
和前人的做法不同的是:该式使得模型对同一样本的强数据增强的预测结果和弱数据增强的预测结果的伪硬标签尽可能地接近。是模型对两种不同的数据增强给出的结果去算交叉熵,而不是对一个数据增强的软伪标签和对应的硬伪标签算交叉熵。具体的流程如下图所示:
整体的损失表示为ls+λulu(λu是超参数、控制无标签样本在模型训练中的相对权重)
作者的发现
- 半监督性能可能会受到SSL算法之外的其他因素的影响,因为诸如正 则化之类的考虑在低标签环境中可能尤为重要。优化器、网络架构、学习率上的选择也同等重要。
- 在本文的所有的模型和实验中,使用简单的权重衰减正则化(L2正则化)
- 我们还发现,使用亚当优化器[22]导致性能变差,转而使用带动量的标准SGD
- 我们使用余弦学习率衰减[28]将学习速率设置为ηcos(7πk/16K)其中η是初始学习速率,k是当前训练步骤,K是训练步骤总数
- 我们使用指数移动平均更新模型参数。
实验细节
- 阈值τ这一超参数的最优值是0.95
- 阈值0.95显示了最低的错误率,尽管将其提高到0.97或0.99并没有造成太大影响。相反,使用较小的阈值时,准确度下降了超过1.5%。请注意,阈值控制了伪标签的质量和数量之间的权衡。无标签数据的伪标签准确性随着阈值的提高而增加,在0.95时达到峰值,但阈值设的越高,参与计算交叉熵的伪标签样本就越少,这表明,达到高准确度时,伪标签的质量比数量更重要。另一方面,当使用置信度阈值时,锐化在性能上没有表现出显著差异,因此用了置信度阈值来筛选高置信度的伪标签样本就不用再锐化(sharpening)了。
- 我们发现,Cutout和CTAugment都采用才能获得最佳性能;删除这两个中的任一个都会导致出错率显著增加
- 我们还研究了用于生成伪标签和预测(即上述流程图中的上下路径)的弱和强增强的不同组合。当我们将用于标签猜测的弱增强替换为强增强时,我们发现模型在训练早期发散。相反,当将弱增强替换为无增强时,模型对无标签标签过拟合。使用弱增强代替强增强来生成模型的训练预测,准确率最高达到45%,但不稳定,并逐步下降到12%。这表明强数据增强的重要性。
不成熟的想法
-
阈值τ改成一个随epoch增加而不断变大的参数会不会更好?
-
通篇未使用mixup这一经典的数据增强方法,使用mixup会不会更好?
-
FixMatch给出的整体损失是有监督的损失项和无监督的损失项相加,意思是无标签的样本和有标签的样本在一个epoch中同时参与训练?
-
把两个损失项拆开,先只用第一个损失项有监督地训练模型,然后用训练好的模型给无标签样本打伪标签,再用第二个损失项训练模型效果会不会更好?