Augmenting Few-Shot Learning With Supervised Contrastive Learning

https://github.com/taemin-lee/SPTA

ABSTRACT

小样本学习处理少量数据,这会导致传统交叉熵损失的性能不足。我们为小样本学习场景提出了一种预训练方法。也就是说,考虑到特征提取器的质量是少样本学习的关键因素,我们使用对比学习技术来增强特征提取器。据报道,将监督对比学习应用于转导小样本训练管道中的基类训练可以提高结果,优于 Mini-ImageNet 和 CUB 上的最新方法。此外,我们的实验表明,当存在域转移退化时,需要更大的数据集来保持小样本分类精度,并且如果应用我们的方法,就不需要大数据集了。在资源受限的环境中,精度增益可以转化为运行时间减少 3.87 倍。

索引术语 Few-shot learning、contrastive learning、information maximization.。

五、结论

我们建议在少样本学习的第一阶段将监督对比学习应用于预训练。 特征提取器使用监督对比损失进行训练,然后进行微调,而分类器使用 TIM 损失进行适应。 我们报告说我们的方法是数据高效的(即,适用于小型数据集),同时在大型数据集上保持具有竞争力的准确性性能。 我们的实验表明,我们在 Mini-ImageNet 和 CUB 数据集上取得了新的最先进的结果

II. RELATED WORK

B. CONTRASTIVE LEARNING

对比学习[2]、[10]、[13]、[36]、[41]、[48]是一种受噪声对比估计[9]、[28]或N-pair loss启发的自监督学习方法[ 38]。 [48] 提出在 CNN 提取图像的特征向量后,使用非参数 softmax 分类器来增加 128 维单位球体上的实例级距离。 [13] 改进了对比预测编码,以实现具有特征提取器和上下文网络的预训练阶段,以预测图像块的空间位置。 Deep InfoMax [10] 提出了一种训练编码器的方法,该方法可以最大化输入数据和输出特征之间的互信息(the mutual information)。 [41]旨在通过将同一场景的视图拉在一起并将不同场景的视图分开来最大化同一图像的不同视图之间的互信息。时间对比网络 (TCN) [36] 提出通过将锚点和正图像拉在一起同时将负图像分开来从多视图视频中学习。 SimCLR [2] 实现了两个数据增强方式和一个可学习的非线性变换,通过从同一图像中提取特征嵌入来训练具有大批量的编码器。监督对比学习 [14] 是传统对比学习的扩展,已针对监督分类进行了修改。

III. METHODOLOGY

A. PROBLEM DEFINITION

在转导设置(the transductive setting)中,模型可以一次访问包括查询集(即 N ×K +N ×T 个样本)在内的整个数据集,而不是在传统的归纳设置(the traditional inductive setting)中一个接一个(即,每个 N ×K +1 个样本)。N ×K +N ×T 代表支持集加查询集。

B. EXAMINING A FEW-SHOT LEARNING METHOD

在这项研究中,我们研究了转导信息最大化 (TIM,the transductive information maximization) 少样本学习算法 [1]。 首先,特征提取器将输入图像转换为嵌入特征。 TIM 通过更新 soft-classifier的可训练权重来最大化查询图像特征和查询标签之间修改后的互信息(the modified mutual information)。 为了最大化信息,TIM 最小化条件熵(the conditional entropy)并最大化边际熵( the marginal entropy )。 最小化条件熵旨在通过对集群假设建模来做出自信的预测,这意味着分类标准不应出现在未标记特征的密集区域中。 最大化边际熵推动标签的边际分布是均匀的,这试图避免只输出一个类的解决方案。 与传统的交叉熵损失一起,TIM 损失定义如下:
在这里插入图片描述其中 pin 是给定特征的标签上的后验分布,而 pn 是查询标签上的边缘分布。

给定损失目标,提出了两种优化方法[1]。 一种是传统的梯度下降(TIM-GD)方法,它通过小批量采样来最小化损失目标。 虽然 TIM-GD 显示出最好的结果,但它比归纳方法慢了两个数量级,这就引出了第二种方法,称为交替方向法 (TIM-ADM),它将问题分成两个更易于管理的子问题并迭代优化 . 与 TIM-GD 相比,TIM-ADM 显示出具有竞争力的结果,同时速度快了一个数量级。 在这两种方法中,都需要足够多的迭代才能收敛到最佳结果。 TIM-GD 和 TIM-ADM 的迭代次数的典型值分别为 1,000 和 150。

C. AUGMENTING FEW-SHOT LEARNING WITH SUPERVISED CONTRASTIVE LEARNING

特征提取器的质量是改进小样本学习算法的主要挑战之一,因为它与特征嵌入的质量直接相关。 监督对比学习[14]是自我监督表示学习的扩展; 它具有类似的两阶段训练过程,如图 1 所示。第一阶段准备输入图像的两个副本并对其进行预处理。 然后,编码器网络将图像转换为归一化嵌入,另外一个投影网络将嵌入转换为低维嵌入。 通过吸引具有相同类别标签或来自相同复制图像的正样本,并通过排斥负样本,在低维嵌入上计算监督对比损失。 有监督的对比损失定义如下:
在这里插入图片描述
其中 zl 是低维嵌入,τ 是温度参数,A(i) ≡ I \ {i},i 是锚索引,P(i) ≡ {p ∈ A(i) : ̄yp = ̄yi} 是除锚点外所有正例的索引集。 嵌入空间上的内积运算衡量两个特征嵌入之间的相似性。 当anchor的特征嵌入与所有正的特征嵌入相似并且与所有负的特征嵌入不同时,损失被最小化。 损失是从传统的 SimCLR [2] 自监督对比损失推广而来的,以支持多视图批次中的多个正例。

值得注意的是,提出了在少样本学习的第一阶段进行监督对比学习,而不是使用基类和交叉熵进行常规训练。训练过程的第二步是丢弃投影网络并使用新的分类器微调编码器网络。正如表示学习所暗示的那样,编码器网络在训练过程的第一步中变得具有辨别力;因此,微调过程相对较短,并且以较低的学习率为指导。请注意,我们使用基类和交叉熵对特征提取器进行了微调,这是在监督对比学习的第一阶段预训练的。监督对比学习中的微调过程是可选的;我们可以跳过不涉及特征提取器的过程,因为我们最后只使用特征提取器。当我们遵循线性评估协议时,我们保持特征提取器不变,这意味着我们跳过了微调过程。我们选择使用微调方法,因为它比没有微调产生更好的结果,如表 1 所示。

在我们的实验中,我们在第一个小样本训练阶段添加了一种监督对比学习方法作为额外的预训练步骤。此外,我们使用基类数据集微调了具有交叉熵损失的特征提取器。在这里插入图片描述图 1. 为少样本学习提出的预训练方法包括一个多阶段的训练过程。 监督对比学习的第一阶段使用监督对比损失和投影头与基础数据集来学习视觉表示。 监督对比学习的第二阶段使用传统的交叉熵损失和基础数据集来微调特征提取器。 这种两阶段的监督对比学习包括小样本学习的第一阶段。 少样本学习的第二阶段使用 TIM [1] 损失和使用新数据集固定的特征提取器以执行 TIM 适应。 如果监督微调成为标准监督训练,并跳过监督对比预训练,那么整个管道与基线方法 [1] 中的相同。

IV. EXPERIMENTS

C. IMPLEMENTATION DETAILS

在实施 [1]、[46] 之后,我们主要研究了三种不同的骨干网络模型,即 ResNet-18、MobileNet 和 WRN28-10。我们在表 2 中进一步检查了另外两个 ResNet 变体,即 ResNet-10 和 ResNet-12,以进行公平比较。请注意,ResNet 后面的数字表示网络的深度。尽管如此,我们按照 [1]、[59] 的约定在一组中报告 ResNet 变体。我们主要研究了 TIM 算法的交替方向法 (ADM) 版本,它比梯度下降 (GD) 版本更快。我们在 TIM 中添加了原型估计技术 [26]、[59]。这进一步提高了 1-shot 分类精度。我们在 N = 3 和 M = 20 的监督对比学习 的预处理阶段使用 PyTorch [30] 复现的 RandAugment 来实现修改后的堆叠式 RandAugment。在预训练时,我们使用 1000 个训练 epoch 进行监督对比学习,然后进行 5 个微调 epoch。我们的方法在组合方法(即监督对比学习、原型估计和 TIM-ADM)的名称之后被称为 SPTA。

I. LIMITATIONS

在域转移实验中,我们观察到通过我们的方法训练的特征提取器在域转移设置下(即表 5 中的最后两行)没有提高少样本学习的准确性。 这意味着我们的方法高度依赖于基础数据集,因为它使用基础数据集消耗了大量的 epoch(即 1,000 个用于监督对比学习的 epoch)。
因此,我们建议我们的方法的应用仅限于不存在域转移的场景。 没有域转移设置鼓励在实际实现中使用更小的基础数据集。

监督对比学习的成本是另一个限制。 建议批量大小大于基础数据集中的类数,以便在单个多视图批次中提供足够数量的阳性结果。 这意味着需要许多图形处理单元 (GPU) 来实施和阻碍广泛的实验。 具体来说,我们使用两个 GTX 2080 Ti GPU 到六个 P100 GPU 来支持一次运行适当批量大小的监督对比学习。 因此,我们强调需要有足够计算能力的服务器来实现预训练阶段。 请注意,一旦完成预训练阶段和微调,剩余的算法就可以在资源受限的环境中实现。

总之,这两个限制都表明我们的方法在数据集大小方面的可扩展性不足,因此,它在小规模应用中是有效的(例如,few-shot learning)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值