计算机视觉中的半监督学习

2020-07-19 22:05:39

作者:Amit Chaudhary

编译:ronghuaiyang

导读

图解半监督的各种方法的关键思想。

计算机视觉的半监督学习方法在过去几年得到了快速发展。目前最先进的方法是在结构和损失函数方面对之前的工作进行了简化,以及引入了通过混合不同方案的混合方法。

在这篇文章中,我会通过图解的方式解释最近的半监督学习方法的关键思想。

1、自训练

在该半监督公式中,对有标签数据进行训练,并对没有标签的数据进行伪标签预测。然后对模型同时进行 ground truth 标签和伪标签的训练。

计算机视觉中的半监督学习

 

a. 伪标签

Dong-Hyun Lee[1]在 2013 年提出了一个非常简单有效的公式 —— 伪标签。

这个想法是在一批有标签和没有标签的图像上同时训练一个模型。在使用交叉熵损失的情况下,以普通的监督的方式对有标签图像进行训练。利用同一模型对一批没有标签的图像进行预测,并使用置信度最大的类作为伪标签。然后,通过比较模型预测和伪标签对没有标签的图像计算交叉熵损失。

计算机视觉中的半监督学习

 

总的 loss 是有标签和没有标签的 loss 的加权和。

计算机视觉中的半监督学习

 

为了确保模型已经从有标签的数据中学到了足够的知识,在最初的 100 个 epoch 中,αt 被设置为 0。然后逐渐增加到 600 个 epochs,然后保持不变。

计算机视觉中的半监督学习

 

b. Noisy Student

Xie 等[2]在 2019 年提出了一种受知识蒸馏启发的半监督方法“Noisy Student”。

关键的想法是训练两种不同的模型,即“Teacher”和“Student”。Teacher 模型首先对有标签的图像进行训练,然后对没有标签的图像进行伪标签推断。这些伪标签可以是软标签,也可以通过置信度最大的类别转换为硬标签。然后,将有标签和没有标签的图像组合在一起,并根据这些组合的数据训练一个 Student 模型。使用 RandAugment 进行图像增强作为输入噪声的一种形式。此外,模型噪声,如 Dropout 和随机深度也用到了 Student 模型结构中。

计算机视觉中的半监督学习

 

一旦学生模型被训练好了,它就变成了新的老师,这个过程被重复三次。

2、一致性正则化

这种模式使用的理念是,即使在添加了噪声之后,对未标记图像的模型预测也应该保持不变。我们可以使用输入噪声,如图像增强和高斯噪声。噪声也可以通过使用 Dropout 引入到结构中。

计算机视觉中的半监督学习

 

a. π-model

该模型由Laine 等[3]在 ICLR 2017 年的一篇会议论文中提出。

关键思想是为标记数据和未标记数据创建两个随机的图像增强。然后,使用带有 dropout 的模型对两幅图像的标签进行预测。这两个预测的平方差被用作一致性损失。对于标记了的图像,我们也同时计算交叉熵损失。总损失是这两个损失项的加权和。权重 w(t)用于决定一致性损失在总损失中所占的比重。

计算机视觉中的半监督学习

 

b. Temporal Ensembling

该方法也是由Laine 等[4]在同一篇论文中提出的。它通过利用预测的指数移动平均(EMA)来修正模型。

关键思想是对过去的预测使用指数移动平均作为一个观测值。为了获得另一个观测值,我们像往常一样对图像进行增强,并使用带有 dropout 的模型来预测标签。采用当前预测和 EMA 预测的平方差作为一致性损失。对于标记了的图像,我们也计算交叉熵损失。最终损失是这两个损失项的加权和。权重 w(t)用于决定稠度损失在总损失中所占的比重。

计算机视觉中的半监督学习

 

c. Mean Teacher

该方法由Tarvainen 等[5]提出。泛化的方法类似于 Temporal Ensembling,但它对模型参数使用指数移动平均(EMA),而不是预测值。

关键思想是有两种模型,称为“Student”和“Teacher”。Student 模型是有 dropout 的常规模型。教师模型与学生模型具有相同的结构,但其权重是使用学生模型权重的指数移动平均值来设置的。对于已标记或未标记的图像,我们创建图像的两个随机增强的版本。然后,利用学生模型预测第一张图像的标签分布。利用教师模型对第二幅增强图像的标签分布进行预测。这两个预测的平方差被用作一致性损失。对于标记了的图像,我们也计算交叉熵损失。最终损失是这两个损失项的加权和。权重 w(t)用于决定稠度损失在总损失中所占的比重。

计算机视觉中的半监督学习

 

d. Virtual Adversarial Training

该方法由Miyato 等[6]提出。利用对抗性攻击的概念进行一致性正则化。

关键的想法是生成一个图像的对抗性变换,这将改变模型的预测。为此,首先,拍摄一幅图像并创建它的对抗变体,使原始图像和对抗图像的模型输出之间的 KL 散度最大化。

然后按照前面的方法进行。我们将带标签/不带标签的图像作为第一个观测,并将在前面步骤中生成的与之对抗的样本作为第二个观测。然后,用同一模型对两幅图像的标签分布进行预测。这两个预测的 KL 散度被用作一致性损失。对于标记了的图像,我们也计算交叉熵损失。最终损失是这两个损失项的加权和。采用加权偏置模型来确定一致性损失在整体损失中所占的比重。

计算机视觉中的半监督学习

 

e. Unsupervised Data Augmentation

该方法由Xie 等[7]提出,适用于图像和文本。在这里,我们将在图像的上下文中理解该方法。

关键思想是使用自动增强创建一个增强版本的无标签图像。然后用同一模型对两幅图像的标签进行预测。这两个预测的 KL 散度被用作一致性损失。对于有标记的图像,我们只计算交叉熵损失,不计算一致性损失。最终的损失是这两个损失项的加权和。权重 w(t)用于决定稠度损失在总损失中所占的比重。

计算机视觉中的半监督学习

 

3、混合方法

这个范例结合了来自过去的工作的想法,例如自我训练和一致性正则化,以及用于提高性能的其他组件。

a. MixMatch

这种整体方法是由Berthelot 等[8]提出的。

为了理解这个方法,让我们看一看每个步骤。

i. 对于标记了的图像,我们创建一个增强图像。对于未标记的图像,我们创建 K 个增强图像,并对所有的 K 个图像进行模型预测。然后,对预测进行平均以及温度缩放得到最终的伪标签。这个伪标签将用于所有 k 个增强。

计算机视觉中的半监督学习

 

ii. 将增强的标记了的图像和未标记图像进行合并,并对整组图像进行打乱。然后取该组的前 N 幅图像为 W~L~,其余 M 幅图像为 W~U~。

计算机视觉中的半监督学习

 

iii. 现在,在增强了的有标签的 batch 和 W~L~之间进行 Mixup。同样,对 M 个增强过的未标记组和 W~U~中的图像和进行 mixup。因此,我们得到了最终的有标签组和无标签组。

计算机视觉中的半监督学习

 

iv. 现在,对于有标签的组,我们使用 ground truth 混合标签进行模型预测并计算交叉熵损失。同样,对于没有标签的组,我们计算模型预测和计算混合伪标签的均方误差(MSE)损失。对这两项取加权和,用 λ 加权 MSE 损失。、

计算机视觉中的半监督学习

 

b. FixMatch

该方法由Sohn 等[9]提出,结合了伪标签和一致性正则化,极大地简化了整个方法。它在广泛的基准测试中得到了最先进的结果。

如我们所见,我们在有标签图像上使用交叉熵损失训练一个监督模型。对于每一幅未标记的图像,分别采用弱增强和强增强方法得到两幅图像。弱增强的图像被传递给我们的模型,我们得到预测。把置信度最大的类的概率与阈值进行比较。如果它高于阈值,那么我们将这个类作为标签,即伪标签。然后,将强增强后的图像通过模型进行分类预测。该预测方法与基于交叉熵损失的伪标签的方法进行了比较。把两种损失合并来优化模型。

计算机视觉中的半监督学习

 

不同方法的对比

下面是对上述所有方法之间差异的一个高层次的总结。

计算机视觉中的半监督学习

 

在数据集上的评估

为了评估这些半监督方法的性能,通常使用以下数据集。作者通过仅使用一小部分(例如:(40/250/4000/10000 个样本),其余的作为未标记的数据集。

计算机视觉中的半监督学习

 

结论

我们得到了计算机视觉半监督方法这些年是如何发展的概述。这是一个非常重要的研究方向,可以对该行业产生直接影响。

英文原文:https://amitness.com/2020/07/semi-supervised-learning/

### 计算机视觉中的半监督学习 #### 方法 半监督学习结合了有标签和无标签的数据来改进模型性能。在计算机视觉领域,这种方法特别有用,因为获取大量标注图像的成本很高。一种常用的方法是通过一致性正则化,在不同增强版本的同一张图片上强制模型给出相似的结果[^2]。 另一种有效的方式是在特征空间中利用聚类算法找到自然存在的结构,并假设来自相同簇的数据应该具有相同的标签。这可以通过引入额外损失项实现,该损失项鼓励属于同一个簇内的样本拥有更接近的表示形式[^3]。 对于卷积神经网络而言,还可以采用伪标签策略——即先用少量带标签样本来预训练一个基础分类器;接着使用此初步得到的模型去预测未标记数据集上的类别分布情况并挑选置信度较高的作为新增加的真实标签加入到已知集合里面继续迭代优化整个过程直到收敛为止[^1]。 ```python import torch.nn as nn from torchvision import models class SemiSupervisedModel(nn.Module): def __init__(self, num_classes=10): super(SemiSupervisedModel, self).__init__() resnet = models.resnet50(pretrained=True) self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.classifier = nn.Linear(2048, num_classes) def forward(self, x_labeled=None, x_unlabeled=None): if x_labeled is not None: features_labeled = self.backbone(x_labeled).squeeze() logits_labeled = self.classifier(features_labeled) if x_unlabeled is not None: with torch.no_grad(): features_unlabeled_1 = self.backbone(augment(x_unlabeled)).squeeze() features_unlabeled_2 = self.backbone(augment(x_unlabeled)).squeeze() return (logits_labeled,) if x_labeled is not None else () ``` #### 应用场景 在实际应用方面,半监督学习可以应用于医疗影像分析、自动驾驶车辆感知系统以及工业缺陷检测等多个领域。例如,在医学成像诊断任务中,由于高质量的手动分割非常耗时费力,因此仅能获得有限数量的专业医生提供的金标准案例用于训练深度学习模型。此时如果能够充分利用那些未经标注但同样重要的病例资料,则有助于提升最终系统的泛化能力和准确性。 另外,在智能交通管理平台建设过程中也经常遇到类似的问题:摄像头采集回来的道路状况视频流虽然海量存在却难以全部人工审核确认每帧画面的具体语义信息。借助于半监督框架下的先进算法就可以自动识别出大部分正常行驶状态而只需针对少数异常情况进行进一步核查处理即可满足日常运营需求。 #### 最新进展 近年来,随着自监督表征学习的发展,研究人员提出了更多创新性的解决方案来解决传统半监督方法中存在的挑战。比如SimCLR提出的对比学习机制能够在不依赖任何显式的类别标签的情况下有效地挖掘大规模原始多模态信号内部蕴含着丰富的上下文关联模式从而为下游特定目标任务提供更好的初始化权重参数设置方案。 此外还有研究探索如何更好地融合弱监督信息(如边界框位置提示)、部分监督设定下(只有少部分实例被赋予完整描述)以及其他类型的辅助知识源(包括但不限于文本说明文档、音频解说词等),使得即使面对极端稀缺甚至完全缺失指导性线索的情形也能构建起具有一定实用价值的目标探测与跟踪能力。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值