半监督学习——FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

论文:https://arxiv.org/abs/2001.07685

代码:https://github.com/google-research/fixmatch

1. 论文题目与摘要

                               FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

       摘要:半监督学习有效的利用没有标注的数据,从而提高模型的精度。这篇论文,我们将有效的结合两种常见的半监督学习方法:一致性正规化技术和伪标签技术。我们的算法叫做FixMatch,首先把没有标签的图片进行轻微的数据增强,用模型对怎强后的图片进行预测,从而生成为标签。对于每张没有标签的图片,当模型的预测得分高于一定的阈值时,伪标签才起作用。模型预测伪标签的同时,将同样的图片进行强烈的数据增强送入网络,计算损失。虽然方法看起来简单,但是FixMatch在从多的半监督学习方法中达到了最好的效果。仅用了250张标注数据,在CIFAR-10数据集上达到了94.93%的准确率;仅用了40张标注数据,在CIFAR-10数据集上达到了88.61%的准确率(每个类别只取了4张标注数据);因为作者做了很多消融实验,说明不同因素对半监督学习效果的影响,最终FixMatch这种半监督学习方法获得成功。我们的代码已经开源:https://github.com/google-research/fixmatch.

 2. 算法主要流程       

Caption

             首先,图片进行轻微的数据增强,然后输入网络进行预测,生成独热编码的为标签。然后,把同样的图片进行强烈的数据增强,得到预测特征。如果轻微数据增强的预测得分大于一定的阈值,那么生成的为标签就和强烈数据增强的特征计算交叉熵损失。整个过程如上图所示:

3. 实现细节

            从整体来看,FixMatch算法是两种半监督学习算法的简单结合,即一致性正则化技术和伪标签技术。

            FixMatch的损失函数有两部分组成:有标签的图片用有监督的损失Ls,没有标签的图片用无监督的损失Lu, 两个损失都是标准的交叉熵损失。

            首先,看看有监督的损失函数,标准的交叉熵损失函数:

           再看看对于没有标签图片的处理:首先得到伪标签,如果伪标签的得分大于一定的阈值(τ,论文中的阈值取0.95),那么,就用该伪标签和强烈数据增强获得的特征计算交叉熵损失:           

           最后,FixMatch的损失函数为:Ls + λ * Lu, 其中λ是一个超参数,用来平衡两个损失函数的,论文中λ=1。

           论文中超参数的设置如下:

            其中:μ为无标签图片和有标签图片的比例。

            模型训练的伪代码如下图所示:

           当然,作者还做了很多消融实验,例如:调节学习率、选择优化器等等。作者的工作量还是挺大的,但创新点就那么多。

4. 总结

           作者提出了一种简单的半监督学习算法:FixMatch,该半监督学习算法在多个数据集上达到了最先进的的结果。FixMatch搭建了low-label semi-supervised learning 和 few-shot learning的联系,甚至聚类算法。作者每一类仅用一张有标签图片,就获得了很高的准确率。对于有标签和无标签的图片,由于Fixmatch用标准的交叉熵损失函数,所以Fixmatch训练工程仅用几行代码就可以完成。

 

以上是博主对论文的理解,如需讨论,请留言!

要复现FixMatch代码,可以按照以下步骤进行操作: 1. 数据准备:首先,需要准备训练数据集和标注数据集。训练数据集可以是无标签的大型数据集,而标注数据集可以是相对较小的有标签数据集。确保数据集的准备工作已经完成。 2. 构建模型:根据FixMatch论文的说明,构建一个基础模型。可以选择使用图像分类的常见模型,如ResNet、VGG等作为基础模型。 3. 数据增强:为了增加模型的鲁棒性和泛化能力,需要对数据进行增强。可以使用图像处理库,如OpenCV或PIL,来进行旋转、翻转、剪裁等操作。 4. 伪标签生成:使用基础模型对无标签数据集进行预测,并从预测结果中选择置信度较高的样本。将这些样本与其对应的预测结果作为伪标签。 5. 训练过程:使用有标签数据和伪标签数据构建训练集,并使用交叉熵损失函数进行模型训练。可以选择使用SGD或Adam等优化器,并设置适当的学习率和超参数。 6. 批量增强策略:为了进一步提高模型的性能,可以使用批量增强策略。例如,可以每个批次随机选择一部分无标签数据,并根据固定的数据增强策略对其进行增强,以增加数据样本的多样性。 7. 迭代训练:重复执行第4至第6步,直到达到预设的训练轮数或收敛条件。 8. 模型评估:使用测试集对训练好的模型进行评估,计算准确率、精确率、召回率等指标。 以上是一个大致的复现FixMatch代码的流程,具体的实现细节和超参数的选择需要根据实际情况进行调整。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值