《MixMatch: A Holistic Approach to Semi-Supervised Learning》论文阅读报告

1. 算法核心思想
1.1 基本思想

现有的半监督学习方法主要有三种:自洽正则化(Consistency Regularization),最小化熵(Entropy Minimization)和传统正则化(Traditional Regularization)。而MixUp同时兼具了这三种方法的优点:集成了自洽正则化,在图像数据增广中使用了对图像的随机左右翻转和剪切(crop);使用“sharpening”函数,最小化未标记数据的熵;使用了单独的权重衰减并使用MixUp作为正则化器(应用于标记数据点)和半监督学习方法。

MixMatch的伪代码如下图所示(图一),接下来将按照步骤详细介绍MixMatch的每一个部分。

图一 MixMatch算法伪代码

1.2 关键步骤

1.2.1 数据增强

同时对有标记数据和无标记数据做增强。对一个Batch的有标记数据X和一个Batch的无标记数据U做数据增强,对X做一次增强且标签不变,而对U做K次。 

1.2.2 标签猜测

将增强后的未标注数据输入预测模型,模型生成“猜测”标签。为一个Batch中的每一个未标记数据ub的K个增强的猜测标签计算平均值(伪代码第七行所示):

       使用Sharpen 算法对上式得到的标签进行处理,得到标签qb。Sharpen 算法具体操作如下:

       其中,T为超参数,当T趋近于0时,Sharpen(p, T)i 接近于One-Hot 分布,即对某一类别输出概率1,其他所有类别输出概率0,此时分类熵最低。这很好理解,比如在猫狗二分类中,分类器说,这张图片中50%的几率是猫,50%的几率是狗,对各类别分类概率预测比较平均;而使用Sharpen来使得“伪”标签熵更低,即猫狗分类中,要么百分之九十多是猫,要么百分之九十多是狗。

图二 标签猜测与Sharpen过程

       从图中Average到Sharpen的变化也可以看出该操作的作用:使得“伪”标签熵更低,使输出接近于One-Hot 分布

1.2.3 MixUp

       将前两步得到的所有数据增强之后的带标签数据及它们的标签、所有未标注数据及其“猜测”标签整合成以下集合:

       将混合在一起,随机重排得到数据集。最终输出将做了MixUp() 的一个 Batch 的标记数据,以及做了MixUp() 的 K 个Batch 的无标记增广数据

       与之前的Mixup方法不同,MixMatch方法将标记数据与未标记数据做了混合,进行 Mixup。对于两个样本及它们的标签(x1p1), (x2p2),混合后的样本为:

       其中,权重因子λ’使用超参数α通过Beta函数抽样得到:

       关于这个对MixUp的修改,作者给出的解释是需要保持每个Batch中的顺序。这样的操作能让x’更接近于x1而非x2。在Mixup标记数据与混合数据时,这样能增加的权重;在 Mixup 未标记数据时,这样能增加的权重。

1.3 损失函数

       损失函数定义如下:

       其中,对于有标签数据,使用Cross Entropy计算Loss。而对于无标签数据,使用L2 Loss。作者对为何无标签数据不使用Cross Entropy Loss而是L2 Loss做出了解释:因为L2 Loss不像Cross Entropy Loss,它是有界的且对错误的预测不太敏感。在文章引用的[25] Temporal ensembling for semi-supervised learning的第三页提供了更详细解释:Cross Entropy 计算是需要先使用 softmax 函数,将Dense Layer输出的类分数转化为类概率,而softmax函数对于常数叠加不敏感,即如果将最后一个Dense Layer的所有输出类分数同时添加一个常数c, 则类概率不发生改变,Cross Entropy Loss不发生改变。因此,如果对未标记数据使用Cross Entropy Loss, 由同一张图片增广得到的两张新图片,最后一层Dense Layer的输出被允许相差一个常数。而使用L2 Loss, 约束更加严格。

       最终的整体损失函数是两者的加权,其中超参数λu是无监督学习损失函数的加权因子。

1.4 超参数设置

       使用到的超参数包括温度参数T,对未标记数据做增强的次数K,MixUp的Beta函数的α以及无监督权重因子λu。作者在实验中发现,这些超参数中的大多数都是可以固定的,不需要对每个实验或每个数据集进行调优。设置T = 0.5, K = 2,只对不同数据集上的α和λu做调整。开始时可以设置α = 0.75, λu= 100。

2. 实验与结果

       作者主要进行了三类实验:对比实验在标准半监督学习的基准上测试MixMatch的有效性,消融实验验证MixMatch每个部分的贡献,PATE架构验证MixMatch在隐私保护中的应用。

2.1 半监督学习训练结果(对比实验)

2.1.1 实验设置

除非特殊说明,在所有实验中,使用的都是Wide ResNet-28模型。在CIFAR-10和CIFAR-100、SVHN和STL-10这四个数据集上进行评估。对比MixMatch和其他四种半监督方法(Π-Model,Mean Teacher,Virtual Adversarial Training和Pseudo-Label),以及MixUp本身,在四个数据集上的错误率。

2.1.2 实验目的

对比实验,对比MixMatch和现有其他方法(5种)在数据集上的错误率,验证MixMatch方法的高性能。

2.1.3 实验结果

CIFAR-10

       使用从250个到4000个不等的带标注数据来评估每种方法的准确性,由均值和方差反应错误率。设置λu= 75。

表一 六种方法在CIFAR-10上的错误率

图三 六种方法在CIFAR-10上的错误率(折线图)

       在 CIFAR-10 数据集上,使用全部五万个数据做监督学习,最低误差能降到百分之4.13。而使用MixMatch,250个数据就能将误差降到11%,4000个数据就能将误差降到6.24%。这表明MixMatch使用很少的标记数据点就能达到媲美有监督学习的效果,这正是半监督学习希望达到的效果。此外,从折线图中还可以看到Mean Teacher的错误率的方差是比较大的,中心实线附近还有一大片浅绿色的区域,那片区域就代表算法的表现容易震荡,不稳定。而对比就可以看出MixMatch不仅做到精度最优,还能保证算法本身的稳定性(黑色旁边浅黑色的区域很小)

CIFAR-10及CIFAR-100(使用更大的模型):

       为了与先前工作的结果有合理的比较,使用了有2600万个参数的有28层的Wide ResNet模型。

表二 CIFAR-10及CIFAR-100在大模型上的错误率

       由于使用大模型,只将MixMatch和Mean Teacher和SWA做了对比。可以看出MixMatch和先前工作相比,效果相匹配或优于先前工作的最佳结果。

SVHN 及 SVHN+Extra:

       和CIFAR-10类似,使用从250到4000个不等的标签数量来评估SVHN上每个方法的性能。设置λu= 250,α = 0.25。

表三 SVHN上的错误率

图四 SVHN上的错误率(折线图)

表四 SVHN+Extra上的错误率

图五 SVHN+Extra上的错误率(折线图)

表五 MixMatch在SVHN及SVHN+Extra上的错误率

       SVHN+Extra是将SVHN的额外训练集也组合起来一起训练,这样未标注样本的比例远远超过标注的样本。从结果来看,对于MixMatch,SVHN+Extra上的错误率明显低于SVHN上的错误率。相比其他方法,MixMatch的错误率明显更低,接近监督学习方法(图四、图五)

STL-10

       STL-10包含5000个训练示例。先前的工作部分使用了全部的5000个标注数据,因此在使用1000/5000个标注的情况下进行对比实验。

表六 STL-10上的错误率

       表六中的方法对比没有没有使用相同的实验设置(即模型),因此很难直接比较结果;然而,因为MixMatch的错误率相当于baseline的1/2,因此也能作为证明MixMatch算法有效的证据之一。设置λu= 50。

2.2 消融实验

由于MixMatch结合了各种半监督学习机制,因此,进行消融实验,通过移除或添加组件,进一步了解是哪些部分使得MixMatch表现更好。

具体来说,评估了以下部分的效果:

1)无标签数据的数据增强的次数K;2)移除温度参数T;3)在生成猜测标签时,使用EMA(与Mean Teacher类似);4)只在有标记数据内,只在无标记数据内进行MixUp,并且不混合使用有标记和无标记的数据;5)使用插值一致性训练(Interpolation Consistency Training),这可以被视为本消融研究的一个特例:只对无标记数据进行MixUp,不使用Sharpen函数,EMA方法用于伪标签生成。

消融实验结果如下:

表七 消融实验结果(CIFAR-10)

       从结果看出,每个部分都对MixMatch的性能有贡献,其中贡献较大的是MixUp以及Sharpen操作,而使用EMA会略微损害MixMatch的性能。

2.3 隐私保护学习与泛化

       用于评估方法的泛化性能。并非文章重点,略。

3. 分析与思考
3.1 创新点与贡献
  1. 提出了一个比较通用的训练框架,集成了现有的几个好方法在一起:集成了自洽正则化,在图像数据增广中使用了对图像的随机左右翻转和剪切(crop);使用“sharpening”函数,最小化未标记数据的熵;使用了单独的权重衰减并使用MixUp作为正则化器(应用于标记数据点)和半监督学习方法。
  2. 集成了比较新的一种数据增强方法MixUp,且对MixUp稍微做了改进,引入权重因子λ’确保混合数据时:生成时,标记数据与混合数据中增加的权重;生成时,未标记数据中增加的权重。
3.2 优缺点分析

3.2.1 方法优点

       提出的MixMatch方法在降低错误率方面效果显著,效果媲美监督学习,优于当时的其他方法。利用很少的标注数据取得媲美监督学习的效果,这正是半监督学习希望实现的。

3.2.2 方法缺点

  1. 时间消耗:对每个无标注数据进行多次增强,时间消耗较大;sharpen函数的处理可能也耗时较大(因为sharpen函数类似于softmax,softmax算子一向比较耗时)。
  2. 对于K个增强的样本都采取同一个标签既不能很好地控制伪标签和真实标签对于算法准确率的贡献比例,也会显得有些不够灵活
3.3 其他思考
  1. 文中的实验只在较小的数据集上进行,没有使用ImageNet这种较大的数据集。
  2. 文中的对比消融实验都比较完备,是从实践角度评价较高的一项工作
  3. 文中只在图像上做了实验,感觉也可以在文本数据上使用MixMatch
  4. 在MixMatch之后的一个工作Unsupervised Data Augmentation (UDA),在 CIFAR10 数据集上,使用4000张标记图片,将预测误差降低到了 5.27%,对比MixMatch,预测误差为 6.24%。UDA主要发明了一个技巧TSA,用于缓慢释放有标签数据的信号,这个跟MixMatch里把无标签数据增强K次,效果是类似的;UDA还集成了最新的数据增广方法,比如 AutoAugment, Back translation;UDA使用KL散度作为损失函数。猜想可不可以将MixMatch和UDA的想法混合在一起,比如使用L2+KL散度作为损失函数,或者在MixMatch数据增强时加入AutoAugment和Back translation,或许对进一步提高准确率有帮助。

参考文献:

  1. Berthelot, D., Carlini, N., Goodfellow, I.J., Papernot, N., Oliver, A., & Raffel, C. (2019). MixMatch: A Holistic Approach to Semi-Supervised Learning. ArXiv, abs/1905.02249.
  2. Laine, S., & Aila, T. (2016). Temporal Ensembling for Semi-Supervised Learning. ArXiv, abs/1610.02242.
  3. Xie, Q., Dai, Z., Hovy, E.H., Luong, M., & Le, Q.V. (2019). Unsupervised Data Augmentation for Consistency Training. arXiv: Learning.
  4. 知乎:超强半监督学习 MixMatch (超强半监督学习 MixMatch - 知乎)
  5. 知乎:MixMatch论文阅读 (MixMatch论文阅读 - 知乎)
  6. 知乎:超强半监督学习MixMatch姐妹篇 Unsupervised Data Augmentation (超强半监督学习MixMatch姐妹篇 Unsupervised Data Augmentation - 知乎)
  7. 知乎:MixMatch 和 UDA比较 (MixMatch 和 UDA比较 - 知乎)
  • 13
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值