Fixmatch:用一致性和置信度简化半监督学习
FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence
- 作者:Kihyuk Sohn∗ David Berthelot∗
- 单位:Google Research
- 发表时间:2020年
- 来源:NeurIPS 2020(国际顶级会议)
- code:https://github.com/google-research/fixmatch
- 数据集:CIFAR-10
- 参考论文:https://arxiv.org/abs/2001.07685
1.要解决什么问题?
半监督学习是利用未标记数据的有效方法,但是目前的许多实现的SSL方法都过于复杂,本文提出了一种显著简化现有SSL方法的算法,即Fixmatch.
2.解决方法?
新的结合一致性正则化和伪标签的半监督学习方法
具体操作:首先把没有标签的图片进行轻微的数据增强(弱增强),用模型对增强后的图片进行预测,从而生成为标签。对于每张没有标签的图片,当模型的预测得分高于一定的阈值时,伪标签才起作用。模型预测伪标签的同时,将同样的图片进行强烈的数据增强(强增强)送入网络,计算损失。
首先,图片进行轻微的数据增强,然后输入网络进行预测,生成独热编码的为标签。然后,把同样的图片进行强烈的数据增强,得到预测特征。如果轻微数据增强的预测得分大于一定的阈值,那么生成的为标签就和强烈数据增强的特征计算交叉熵损失。整个过程如上图所示:
3.损失函数
从整体来看,FixMatch算法是两种半监督学习算法的简单结合,即一致性正则化技术和伪标签技术。 FixMatch的损失函数有两部分组成:有标签的图片用有监督的损失Ls,没有标签的图片用无监督的损失Lu, 两个损失都是标准的交叉熵损失。
主要创新性来自于一致性正则和伪标签两种成分的结合,以及在执行一致性正则化时分别使用弱和强增强。
有监督的损失函数,标准的交叉熵损失函数:
对于没有标签图片的处理:首先得到伪标签,如果伪标签的得分大于一定的阈值(τ,论文中的阈值取0.95),那么,就用该伪标签和强烈数据增强获得的特征计算交叉熵损失:
最后,FixMatch的损失函数为:Ls + λ * Lu, 其中λ是一个超参数,用来平衡两个损失函数的,论文中λ=1。
4.实现细节
论文中超参数的设置如下:
其中:μ为无标签图片和有标签图片的比例。
模型训练的伪代码如下图所示:
5.实验
数据集:CIFAR10/100, SVNH
除了ReMixMatch外,没有其他工作考虑每个类仅有少于25的标注数据,本文考虑了每个类仅有4个标注样本的情况。
除了cifar100上ReMixMatch最佳,其他都是FixMatch最佳。在将ReMixMatch的各成分移植到FixMatch的过程中,发现最重要的一项是分布对齐(Distribution Alignment),它鼓励模型以等概率emit所有类别。将DA组合到FixMatch达到了超越ReMixMatch本身的效果。
FixMatch使用RandAugment和CTAugment在大部分情况下性能相近,除了每类只有4个标注样本的情况,这可以被结果存在高方差解释。
6.消融实验
研究了温度T和置信阈值τ之间的相互作用。
对不同的强数据增强策略进行了消融研究。
作者还做了很多消融实验,例如:调节学习率、选择优化器等等。
7.总结
作者提出了一种简单的半监督学习算法:FixMatch,该半监督学习算法在多个数据集上达到了最先进的的结果。FixMatch搭建了low-label semi-supervised learning 和 few-shot learning的联系,甚至聚类算法。作者每一类仅用一张有标签图片,就获得了很高的准确率。对于有标签和无标签的图片,由于Fixmatch用标准的交叉熵损失函数,所以Fixmatch训练工程仅用几行代码就可以完成。