MixMatch:半监督学习

1 摘要

半监督学习已被证明是利用未标记数据减轻对大型标记数据集依赖的一个强大范例。在这项工作中,我们结合了目前半监督学习的主流方法,提出了一种新的算法,MixMatch,它利用MixUp方法猜测数据中的低熵标签(low-entropy labels),这些数据包括了数据扩充之后的未标记样本和混合数据(未标记和标记的混合数据)。我们展示了MixMatch在许多数据集和标记的数据量上获得了大量最新的结果。例如,在包含250个标签的CIFAR-10上,我们将错误率降低了4倍(从38%降低到11%),在STL-10上降低了2倍。我们还演示了MixMatch如何帮助实现对差异隐私的更精确的隐私交换。最后,我们进行消融研究,梳理出哪些成分的混合匹配是最重要的成功

2 介绍

最近在训练大型深度神经网络方面取得的成功,在一定程度上要归功于大型标记数据集的存在。然而,对于许多学习任务来说,收集标记数据是昂贵的,因为它必然涉及到专家知识。这一点或许可以从医学任务中得到最好的说明,在医学任务中,使用昂贵的机械和标签进行测量是耗时分析的结果,通常来自多位人类专家的结论。此外,数据标签可能包含被认为是私有的敏感信息。相比之下,在许多任务中,获取未标记的数据要容易得多,也便宜得多

半监督学习(SSL)通过允许模型利用未标记的数据,试图在很大程度上减轻对标记数据的需求。最近的许多半监督学习方法都增加了一个损失项,这个损失项是在未标记的数据上计算的,它鼓励模型更好地泛化至到不可见的数据中。在最近的许多工作中,这个损失项可分为三类:
熵最小化——它鼓励模型对未标记的数据输出有信心的预测;
一致性正则化——当模型的输入受到扰动时,它鼓励模型产生相同的输出分布
泛型正则化——这有助于模型很好地泛化,避免对训练数据的过度拟合。

在本文中,我们引入了MixMatch,这是一种SSL算法,它引入了单个损失,将这些主要方法优雅地结合到半监督学习中。与之前的方法不同,MixMatci rget一次获得所有属性,我们发现它有以下好处:

  • 实验表明,MixMatch在所有标准的图像基准测试(第4.2节)上都获得了最先进的结果,例如,在包含250个标签的CIFAR-10上获得了11.08%的错误率(其次是最佳方法,获得了38%的错误率);

  • 此外,模型简化测试中表明,MixMatch 的效果比各个trick 混合之和要好;

  • 我们在第4.3节中演示了MixMatch对于不同的私有学习是有用的,使PATE框架[34]中的学生能够获得最新的结果,同时增强所提供的隐私保障和所达到的准确性。

简而言之,MixMatch为未标记的数据引入了一个统一的损失项,它无缝地减少了熵,同时保持一致性,并保持与传统正则化技术的兼容性。

3 已有相关工作

为了设置MixMatch,我们首先介绍SSL的现有方法。我们主要关注那些目前最先进的和MixMatch的基础;有很多关于SSL技术的文献我们在这里没有讨论:

  • transductive
  • graph-based methods
  • generative modeling
    下面,我们将引用一个通用模型
    在这里插入图片描述
    y是输入x的分类类别的标签
    x是输入
    theta 是参数

3.1 Consistency Regularization 一致性正则化

在监督学习中,一种常见的正则化技术是数据增强,它应用于对输入进行转换,同时假定这种转换不影响类语义分类。例如,在图像分类中,输入图像通常会发生弹性变形或添加噪声,这可以在不改变图像标签的情况下显著改变图像的像素内容。粗略地说,这可以通过近乎无限生产新数据或者说修改数据,人为地扩大了训练集的大小。一致性正则化将数据增强应用于半监督学习,它利用了这样一种思想 : 即使对未标记的示例进行了增强,分类器也应该输出相同的类分布。更正式地说,一致性正则化强制未标记的示例x应该与Augment(x)归为一类,其中Augment©是一个随机数据增强函数,类似于随机空间平移或添加噪声。

最简单的例子,Π-Model,也叫做带有随机变化和扰动项的正则化,将下列式子加入了损失函数
在这里插入图片描述
对于未标记的x数据点,我们需要注意:Augment(x) 是一个随机变换,所以上式中Augment(x)中的两项是不相同的。该方法通过旋转、剪切、加性高斯噪声等复杂的增强过程,应用于图像分类基准。例如,“Mean teacher” 将上式中的一项替换为模型的输出,这个模型利用了模型中参数的指数移动平均。这提供了一个更稳定的目标,并在实践中发现显著改善结果。这些方法的一个缺点是,它们使用领域特定的数据增强策略. “虚拟对抗性训练。VAT(Virtual Adversarial Training)解决这个问题的方法是,计算一个加性扰动来应用于最大程度地改变输出类分布的输入。MixMatch通过对图像使用标准数据增强,利用了一致性正则化的一种形式。

3.2 Entropy Minimization/ Entropy regularization 熵最小化

在许多半监督学习方法中,一个常见的基本假设是分类器的决策边界不应该通过边缘数据分布的高密度区域。实现这一点的一种方法是要求分类器对未标记的数据输出低熵预测。

  • 这是在显式地通过简单地添加一个损失项来实现的,该损失项使Pmodel(y | x;0)未标注数据,这种形式的熵最小化与VAT相结合,得到了更强的结果 (VAT)
  • ‘Pseudo-Label 伪标签’ 通过对未标记数据的高置信度预测构建硬标签,并在标准的交叉熵损失中使用这些硬标签作为训练目标,隐式地实现了熵的最小化 (Pseudo-Label, 2013) Pseudo-Label:深度学习中一种简单有效的半监督方法
  • MixMatch还通过对未标记数据的目标分布使用“锐化”函数隐式地实现熵的最小化 (sharpen)

3.3 Traditional regularization 传统正则化

正则化是指施加的约束模型的一般方法难以记忆的训练数据,因此希望把它推广更好的看不见的数据
无处不在的正则化方法是添加一个损失项惩罚L2范数模型的参数,可以被视为执行为identity-covariance高斯之前的重量值
当使用简单的梯度下降法时,这个损失项等于指数衰减权值趋向于零。由于我们使用Adam作为梯度优化器,所以我们使用显式的“重量衰减”而不是L2损失项
最近,提出了一种混合正则化器 MixMatch,它训练输入和标签的凸合并模型。混合模型可以被看作是鼓励有严格的线性行为之间”的例子,通过要求一个或两个凸组合模型的输出输入接近的凸组合输出为每个单独的输入(43、44、18我们利用混合物在MixMatch既是规范(应用对标记点)和半监督学习方法(适用于无标号数据点)。之前,混淆已被应用于半监督学习;特别是的并发工作使用了方法的一个子集

4 MixMatch

在本节中,我们介绍了混合匹配 MixMatch,我们提出的半监督学习方法。MixMatch 是一个“整体”的方法, 它整合了前面提到的一些ideas 和一些来自主流SSL的组件。给定一个已经标签的 batch X 和同样大小未标签的batch U. MixMatch生成一批经过处理的增强标签数据X‘和一批带猜测标签的U’,然后分别计算带标签数据和未标签数据的损失项。更正式地,半监督学习的综合损失L计算如下:
在这里插入图片描述
Hpq为分布p和发布q的交叉交叉熵,T,K,α,λ 是超参数
算法流程:
在这里插入图片描述

4.1 数据增强 Data Augmentation

缓解标记数据不足的常见方法是使用数据增强。数据增强引入了一个函数Augment(x),该函数以其标签不变的方式对输入数据点x进行随机转换。重申一下,不同的增广应用将产生不同的(随机)输出。与许多SSL方法中的典型方法一样,我们对标记的和未标记的数据都使用数据增强。对于批次标记数据X中的每个xb,我们生成一个转换后的版本=Augment(xb) 。对于批次未标记数据U中的每个ub,我们生成K个增强 ub,k=Augment(ub), K 属于(1,…, K)。这些单独的扩展用于为每个ub生成一个“猜测的标签”qb,我们将在下一节中描述这个过程。
笔者笔记:对有标记的标签做1次增强,对未标记的标签做k次增强

4.2 标签猜测 Label Guessing

对于的每个未标记的训练数据U,MixMatch使用模型的预测为示例的标签生成一个“guess”。这个guess后来被用于无监督损失术语。为此,我们计算了该模型预测的分类分布在ub的所有K个增量上的平均值。在一致性正则化方法中,使用数据增强为未标记的示例获取人工目标是常见的。
在这里插入图片描述
在这里插入图片描述
笔者笔记:如此图所示,unlabeled 先经过K次随机增强,然后用 P_model(y l Augment(ub); theta) 进行预测,生成未标记标签的k个qb,然后将这这个分布取平均值,然后再进行Sharpen。对结果进行强化。在这个过程中P_model 只有一个,然后用新数据集再去训练P_model,然后再由P_model 去预测unlabeled. 这样反复迭代

4.3 Sharpening 锐化

笔者笔记:Sharpening是一个很重要的过程,这个思想相当于深度学习中的relu过程,在取平均之后不进行锐化会对结果影响很大。在生成标签猜测时,我们执行了一个额外的步骤,灵感来自于半监督学习中熵最小化的成功(在第2.2节中讨论)。在给定平均预测量的基础上,应用锐化函数减小了标签分布的熵。在实际应用中,对于锐化函数,我们使用了调整这个分类分布[5]的“温度”的常用方法,即操作
在这里插入图片描述
T是超参数。

4.4 MixUp

作为MixMatch的最后一步,我们使用了MixUp。为了在半监督学习中使用mixup,我们将它同时应用于带猜测标签的数据和没有标签的示例。与过去使用MixUp工作不同,我们将标记的示例与未标记的示例“混合”在一起,反之亦然,我们发现改进了性能。在我们的组合损失函数中,我们使用单独的损失术语来表示标记的和未标记的数据。这将导致在最初建议的表单中使用MixUp时出现问题;
相反,对于两个具有相应(one-hot)标签(cz1, pl), (x2, p2)的示例,我们定义了一个稍微修改过的混合,即计算(x’, p’) by
在这里插入图片描述

  • 5
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
MixMatch是一种半监督学习方法,它利用未标记的数据来增强有标记的数据集以提高模型的性能。如果您想使用自己的数据集训练MixMatch模型,可以按照以下步骤操作: 1. 准备数据集:您需要准备一个包含有标记和未标记样本的数据集。有标记样本应该包含输入和相应的标签,而未标记样本只包含输入。您可以根据需求选择合适的数据集,并将其存储在本地或云存储中。 2. 安装依赖项:您需要安装PyTorch和其他必要的Python库,例如NumPy、matplotlib等。 3. 下载MixMatch代码:您需要从Github上下载MixMatch的官方代码,可以通过以下链接获取:https://github.com/google-research/mixmatch 4. 配置参数:您需要打开config.py文件,并根据您的需求修改训练参数,例如批量大小、学习率、训练轮数等。您还需要指定您的数据集路径和其他相关参数。 5. 训练模型:一旦您完成了参数配置,您就可以运行train.py文件开始训练模型。在训练过程中,MixMatch会使用半监督学习方法来利用未标记的数据来增强有标记的数据集。您可以根据需要更改训练参数或停止训练。 6. 评估模型:一旦训练完成,您可以使用test.py文件来测试模型的性能。该文件将输出模型在测试数据集上的准确性和其他相关指标。 这些步骤应该可以帮助您使用自己的数据集训练MixMatch模型。请注意,MixMatch是一种高级算法,需要一定的技术知识和经验才能正确使用。建议您在开始之前仔细学习相关文献和教程,以确保正确使用该算法。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值