【论文笔记】MixMatch: A Holistic Approach to Semi-Supervised Learning

原文链接:https://arxiv.org/abs/1905.02249
本文提出的MixMatch方法结合了之前半监督学习中一系列的有效方法,在仅有少量标注的情况下,在很多数据集上都达到了可以媲美有监督学习的结果。

摘要

半监督学习已被证明是利用未标记数据减轻对大型标记数据集依赖的一个有效地方法。在这项工作中,我们统一了目前半监督学习的主要方法,并产生了一个新的算法—MixMatch。该方法主要通过猜测数据增强的无标数据的低熵标签,并使用MixUp混合有标和无标样例。我们展示了MixMatch在许多数据集和标记的数据量上获得了state of the art的结果。例如,在包含250个标签的CIFAR-10上,我们将错误率降低了4倍(从38%降低到11%),在STL-10上降低了2倍。我们还展示了MixMatch如何帮助实现对差异私有性中精确性和私有性权衡。最后,我们进行消融研究,得出哪些成分是MixMatch取得成功的关键。

介绍

1.现有大多数深度网络的成功依赖于大量的有标数据。对于很多任务收集标注数据很困难,而得到无标数据相对容易。

2.半监督学习(SSL)试图利用无标注的数据来减轻对有标数据的需求。很多SSL方法针对无标注数据增加损失项来使得模型很好的泛化到未见数据上。损失项可分为三类:熵最小化,一致性正则化和一般正则化。

本文介绍了MixMatch利用一个loss将这些方法应用到半监督学习中,有以下贡献:

1.在所有数据集上取得了state-of-the-art。

2.消融实验表明MixMatch效果好于每部分之和。

3.MixMatch对于隐私学习很有效。取得state-of-the-art的同时也保证了隐私性。

相关工作

1.一致性正则化

半监督学习中的一致性正则化利用了这样一个假设,分类器对于数据增强后的的数据的分类分布应该与之前的类别分布一样。损失可以写成下式:
∣ ∣ P m o d e l ( y ∣ A u g m e n t ( x ) ; θ ) − P m o d e l ( y ∣ A u g m e n t ( x ) ; θ ) ∣ ∣ 2 2 ||P_{model}(y|Augment(x);\theta)-P_{model}(y|Augment(x);\theta)||_2^2 Pmodel(yAugment(x);θ)Pmodel(yAugment(x);θ)22
数据增强是随机的,所以希望不同数据增强后的同一张图片能尽可能分到同一类。但这类方法的一大缺点是它们只使用了域特定的数据增强方法。MixMatch通过使用图像的标准数据增强来利用一致性正则化(随机水平翻转和裁剪)

2.熵最小化

半监督学习中,一个基本的潜在假设是分类器的决策边界不应该穿过数据边缘分布的高密度区域。为达到这一目的,需要分类器对于无标注样例输出低熵的预测值。通过加入一个损失项来显式的最小化无标输入的熵: P m o d e l ( y ∣ x ; θ ) P_{model}(y|x;\theta) Pmodel(yx;θ)。伪标签的方法通过对无标样例高置信度的预测值打上硬标签,并进一步带入一个交叉熵损失训练来实现熵最小化。MixMatch也通过对未标记数据的目标分布使用“sharpening”函数来隐式地实现熵的最小化。

3.传统正则化

正则化是指对模型施加约束,使其更难记忆训练数据,从而有望更好地推广到未见数据的一般方法。一种常用方法是增加损失项来惩罚模型参数的L2范数。本文的优化方法为Adam算法,所以使用权值衰减来替代L2损失项。

最近,MixUp方法同时对输入和标签的凸组合训练一个模型,其要求模型对两个输入的凸组合的输出接近每个单独输入的输出的凸组合。我们也在MixMatch中使用了MixUp方法(用于有标数据)。

MixMatch

MixMatch集成了上述方法,给定有标数据集 X \mathcal{X} X和同等大小的无标数据集 U \mathcal{U} U,对有标数据和无标数据进行数据增强分别得到 X ’ \mathcal{X^’} X U ’ \mathcal{U^’} U。它们被分别用来计算有标和无标的损失项,最终Loss如下:
X ’ , U ’ = M i x M a t c h ( X , U , T , K , α ) \mathcal{X^’},\mathcal{U^’}=MixMatch(\mathcal{X},\mathcal{U},T,K,\alpha) XU=MixMatch(X,U,T,K,α)
L X = 1 ∣ X ’ ∣ ∑ x , p ∈ X ’ H ( p , P m o d e l ( y ∣ x ; θ ) ) L_{\mathcal{X}}=\frac{1}{|\mathcal{X^’}|}\sum_{x,p \in \mathcal{X^’}}H(p,P_{model}(y|x;\theta)) LX=X1x,pXH(p,Pmodel(yx;θ))
L U = 1 L ∣ U ’ ∣ ∑ x , p ∈ U ’ ∣ ∣ q − P m o d e l ( y ∣ u ; θ ) ∣ ∣ 2 2 L_{\mathcal{U}}=\frac{1}{L|\mathcal{U^’}|}\sum_{x,p \in \mathcal{U^’}}||q-P_{model}(y|u;\theta)||_2^2 LU=LU1x,pUqPmodel(yu;θ)22
L = L X + λ U L U L=L_{\mathcal{X}}+\lambda_{\mathcal{U}}L_{\mathcal{U}} L=LX+λULU
其中 H ( p , q ) H(p,q) H(p,q)表示分布p和q之间的交叉熵损失, T T T K K K α \alpha α λ U \lambda_{\mathcal{U}} λU是超参数,整个算法流程如下表所示:

1.数据增强

如上文所说,数据增强是减轻缺少有标数据影响的一种方法。类似于大部分半监督学习方法,我们同时对有标和无标数据进行数据增强。对有标数据进行一次数据增强,无标数据进行K次数据增强。这些无标数据增强后得到的结果进行‘laebl guessing’获得 q b q_b qb

2.label guessing

对于单个无标样例,我们计算K次增强后类别预测分布的均值,这个得到的标签带入后续的无监督损失项中。

q b ‾ = 1 K ∑ k = 1 K p m o d e l ( y ∣ u b , k ^ ; θ ) \overline {q_b}=\frac{1}{K}\sum_{k=1}^Kp_{model}(y|\hat{u_{b,k}};\theta) qb=K1k=1Kpmodel(yub,k^;θ)
这个方法在一致性正则化方法中很常见。

3.sharpening

得到了上述label guessing的结果后,使用sharpening方法进行熵最小化处理,如下式:

S h a r p e n ( p , T ) i : = p i 1 T / ∑ j = 1 L p j 1 T Sharpen(p,T)_i := p_i^{\frac{1}{T}}/\sum_{j=1}^{L}p_j^{\frac{1}{T}} Sharpen(p,T)i:=piT1/j=1LpjT1
其中p是类别分布(上述的增强后类别分布的均值),T是超参数。T越趋于0,sharpen的输出就趋向于one-hot。因为后续我们需要使用sharpen的输出作为模型预测的目标值,所以选择较低的T保证了模型可以产生低熵的预测。

4.MixUp

我们同时对有标数据和有label guessing结果的无标数据进行MixUp。我们一开始分别对有标数据和无标数据设置不同的loss,但是这会带来问题。对于一对样例, ( x 1 , p 1 ) (x_1,p_1) (x1,p1), ( x 2 , p 2 ) (x_2,p_2) (x2,p2)我们稍微修改了MixUp方法。通过下式计算得到 ( x ’ , p ’ ) (x^’,p_’) (xp)

λ ∼ B e t a ( α , α ) \lambda \sim {Beta(\alpha,\alpha)} λBeta(α,α)
λ ’ = m a x ( λ , 1 − λ ) \lambda^’=max(\lambda,1-\lambda) λ=max(λ,1λ)
x ’ = λ ’ x 1 + ( 1 − λ ’ ) x 2 x^’=\lambda^’x_1+(1-\lambda^’)x_2 x=λx1+(1λ)x2
p ’ = λ ’ p 1 + ( 1 − λ ’ ) p 2 p^’=\lambda^’p_1+(1-\lambda^’)p_2 p=λp1+(1λ)p2
传统的MixUp可以被看做省略了第二项,即 λ = λ ’ \lambda=\lambda^’ λ=λ。收集所有的有标和无标和label guessing结果使用MixUp。

我们将两部分串联起来并shuffle形成MixUp所需的数据源,对第i个有标样例,计算 M i x U p ( X ^ i , W i ) MixUp(\hat \mathcal{X}_i,W_i) MixUp(X^i,Wi)并加入 X ′ \mathcal{X'} X集合中。由于我们的修改,MixUp的结果应该更接近原始有标数据而不是插值的结果。用剩余的W来计算 U ′ \mathcal{U}' U

据此,MixMatch将 X \mathcal{X} X转变为了 X ′ \mathcal{X}' X,一个包含数据增强后的有标数据和与无标数据MixUp结果的集合。相应的, U \mathcal{U} U转变为了 U ′ \mathcal{U}' U,一个对于每个无标样例进行多重数据增强并包含其label guessing的集合。

5.损失函数

获得了 X ′ \mathcal{X}' X U ′ \mathcal{U}' U之后,利用本节一开始的损失函数,对于有标数据,使用传统交叉熵损失,并加上对于 U ′ \mathcal{U}' U中无标数据的标签预测值的平方L2损失。相较于交叉熵,平方L2损失对错分样例有着更低的敏感性。我们不通过猜测的标签传播梯度。

实验部分感兴趣的读者可以参考原文,这里不再赘述。

  • 7
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
MixMatch是一种半监督学习方法,它利用未标记的数据来增强有标记的数据集以提高模型的性能。如果您想使用自己的数据集训练MixMatch模型,可以按照以下步骤操作: 1. 准备数据集:您需要准备一个包含有标记和未标记样本的数据集。有标记样本应该包含输入和相应的标签,而未标记样本只包含输入。您可以根据需求选择合适的数据集,并将其存储在本地或云存储中。 2. 安装依赖项:您需要安装PyTorch和其他必要的Python库,例如NumPy、matplotlib等。 3. 下载MixMatch代码:您需要从Github上下载MixMatch的官方代码,可以通过以下链接获取:https://github.com/google-research/mixmatch 4. 配置参数:您需要打开config.py文件,并根据您的需求修改训练参数,例如批量大小、学习率、训练轮数等。您还需要指定您的数据集路径和其他相关参数。 5. 训练模型:一旦您完成了参数配置,您就可以运行train.py文件开始训练模型。在训练过程中,MixMatch会使用半监督学习方法来利用未标记的数据来增强有标记的数据集。您可以根据需要更改训练参数或停止训练。 6. 评估模型:一旦训练完成,您可以使用test.py文件来测试模型的性能。该文件将输出模型在测试数据集上的准确性和其他相关指标。 这些步骤应该可以帮助您使用自己的数据集训练MixMatch模型。请注意,MixMatch是一种高级算法,需要一定的技术知识和经验才能正确使用。建议您在开始之前仔细学习相关文献和教程,以确保正确使用该算法。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值