Retrieval-Enhanced Visual Prompt Learning for Few-shot Classification

RePrompt是为了解决小样本视觉分类和域泛化中的挑战而提出的,它通过引入检索机制来缓存知识表示,增强了提示学习的效果。方法包括检索增强的视觉提示、基于检索的适配器和检索引导训练,从而在多个数据集上实现了性能提升。
摘要由CSDN通过智能技术生成

摘要

提示学习已经成为将大型V-L模型(如CLIP)适应下游任务的流行方法。通常,提示学习依赖于固定的提示令牌或输入条件令牌,以便在完全监督下拟合少量数据。虽然这种范式可以推广到一定范围的看不见的类,但当域间隙增加时,它可能会遇到困难,例如细粒度分类和卫星图像分割。为了解决这个限制,我们提出了检索增强提示学习(RePrompt),它引入了检索机制来缓存下游任务的知识表示。我们首先从训练数据中构建一个检索数据库,或者在可用时从外部示例构建检索数据库。然后,我们将这种检索增强机制集成到简单提示学习基线的不同阶段。通过引用训练集中的相似样本,增强模型能够更好地适应样本较少的新任务。我们对 15 个视觉数据集进行了广泛的实验,包括 11 个具有小样本设置和 4 个域泛化基准的下游任务,表明 RePrompt 实现了显着提高的性能。当域间隙增加时,我们提出的方法为提示学习所面临的挑战提供了有希望的解决方案。代码和模型将可用。

动机

本文研究的背景是如何在少量数据情况下进行视觉概念识别,以及如何将大规模的V-L模型应用在该问题上。

先前的方法主要集中在从源域学习可转移的视觉表示,并通过微调技术快速适应小样本下游任务。但是,这些小样本学习算法只适用于实际上比较简单设置,例如区分5-way 1-shot分类。

最近,大规模V-L模型(如CLIP,ALIGN),在视觉表示学习方面显示出有希望的结果。将这些V-L模型应用于更具挑战性的小样本学习问题的可行性引起了广泛关注。CoOp首先提出了一种小样本设置来评估V-L模型的性能,设计了一个可学习的文本提示替换手工制作的文本提示模板,此外,受VPT的启发,许多研究也考虑在image encoder中添加额外的可学习视觉tokens,以充分微调下游任务的V-L模型。

尽管这些方法在小样本任务有显著的提升,但很难推广到极少样本的数据或非典型实例上。最近的一些检索方法通过检索知识语料库以生成附加参考以提高低资源场景中的性能,受这些方法的启发,本文考虑通过检索与当前图像在小样本训练集中的整体特征最相似的相关图像来增强提示学习。

主要贡献

探讨了引入检索系统来增强V-L模型的提示学习的可行性,重点是下游的小样本图像分类任务。

提出了RePrompt,它根据训练数据集建立了一个检索数据库,并在模型输入、中间层、输出和训练过程中实现了检索增强机制。

在各种few-shot设置下,所提出的RePrompt在11个视觉数据集上实现了state-of-the-art的性能。还在4个领域泛化基准测试中展示了卓越的性能。

方法

VLPT(视觉-语言提示微调):提出的提示学习基线

在本节中,我们提出了一个简单的双模态提示调优基线,用于小样本图像分类任务,称为视觉语言提示调优(VLPT)。VLPT同时从小样本训练数据中学习图像编码器和文本编码器的视觉和文本提示。视觉提示通过visual prompt tuning学习,而文本提示通过text prompt tuning学习。在得到类别的图像特征和文本特征后,经过调整提示的模型根据CLIP的zero-shot分类范式生成预测。

CLIP[38]是一个视觉语言预训练模型,包括两个子网络:图像编码器e_{I}和文本编码器e_{T}。这些编码器分别将视觉和文本输入编码到联合隐藏空间R^{d}中,在该空间中视觉语义和语言模态很好地对齐[54]。这里,d是嵌入的维度(例如,对于ViT[8],d=512)。对于一张图像x和一组C类别名称T={t_{1},t_{2},...,t_{C}}(例如,对于ImageNet[7]的C=1000,即有1000个类),图像编码器提取图像特征z = e_{I}(x)∈R^{d}。T中的类别名称与手工制作的文本提示模板"a photo of [CLASS]"连接作为文本描述。将这些描述进一步输入到文本编码器中,以得到文本特征f∈R^{d\times C}。x属于c类的预测概率由查询图像特征z与文本特征f的内积相似度计算。

在实践中,对每个下游任务的迁移学习的整个视觉语言模型进行微调是很麻烦的。此外,调整后的模型遭受灾难性遗忘[26,44],在预训练模型最初表现良好的新任务上表现不佳。因此,我们需要一个没有额外微调的传输策略。

Text prompt learning:与NLP中的硬手工提示工程不同,文本提示学习考虑通过学习一组提示来生成更自适应的文本特征。CoOp[54]学习一组参数P_{T}R^{d\times M}来替换手工制作的提示模板,其中M是提示长度。T中每个类别名称的单词标记被填充到P_{T}的模板中,并被视为可学习输入提示的初始值。这些提示在小样本数据上进一步微调,以生成文本特征f_{t}。经过微调的提示符P_{T}为下游任务调整文本特征的决策边界。在此过程中,从预训练的CLIP模型继承的所有参数在训练过程中被冻结。

Visual prompt learning:VPT[17]提出了一种视觉提示调优方法,该方法在冻结图像编码器e_{I}以从下游数据中提取更多可迁移的视觉特征的同时,在输入空间中引入一些可学习的参数。在 L层的图像编码的上下文中,第 i 层输出 l_{i},其中 i = 1, 2,..., L, 可以表示为:

 其中c∈R^{d}表示分类标记。Z^{i}={z1,…,zS}∈R^{d\times S}表示第i层的一组输入图像patch标记,长度为S。此外,在第i层输入序列中引入了可学习视觉提示P_{I}R^{d\times N},其中N是视觉提示的长度,如下所示:

有两种视觉提示变体,VPT Shallow和VPT Deep。在VPT Shallow方法中,将类标记与图像patch标记和视觉提示相结合作为第一层的输入,而VPT Deep在每一层中插入独立的视觉提示。

现有的视觉语言模型的提示方法往往侧重于视觉或文本提示学习,而忽略了双编码器中提示学习的灵活性。此外,由于在低数据资源设置中使用可学习参数进行完全监督训练,提示学习方法可能无法很好地推广到非典型示例上。考虑到识别任务的特点和视觉提示学习的灵活性,我们在视觉提示学习中引入了检索和关联机制,以提高性能。

本文的方法:Retrieval-enhanced prompt tuning(检索-增强提示调优)

在本节中,我们将详细阐述针对VLPT提出的检索增强提示调优RePrompt利用从训练数据集(4.1)中检索相关信息来增强提示调优。图1说明了RePrompt的总体工作流程,它包括视觉语言提示调优基线(3)和三个检索增强模块,检索增强的视觉提示(4.2)、基于检索的适配器(4.4)和检索指导训练(4.3)。

 (RePrompt的整个工作流程包括四个主要步骤。(a)通过冻结基准图像编码器将图像输入编码为查询embedding;(b)使用相同的冻结图像编码器将训练数据集中的每个图像编码为key和value嵌入对。此外,value还包括标签的单次表示。我们通过最大内积搜索检索top-K相关知识项,并将这些知识融合为条件生成可视化提示。(c)在视觉分支的J层输入中引入检索增强的视觉提示,其余提示与基线VLPT的提示一致。(d)将经过提示调整的CLIP预测和基于检索的适配器预测线性组合得到最终输出。)

Retrieval module(检索模块)

为了充分利用现有的预训练模型,我们利用冻结图像编码器e_{R}来得到查询图像特征z_{q}R^{d}。注意,e_{R}是通过视觉语言预训练的冻结图像编码器模型。我们使用检索数据库作为一个健壮的token存储库,随着下游任务的变化而变化。请参阅补充材料,了解详细的可视化,以证明与监督学习模型相比,使用这些模型作为编码器的必要性。我们的检索模块包括两个步骤:(1)建立数据库,(2)检索。

Retrieval database(检索数据库):数据库由从小样本训练数据集D中提取的特征构建而成,检索数据库中有|D|个键值对(k_{i},v_{i})。键k_{i} = e_{R} (x_{i})∈R^{d}是由冻结e_{R}提取的训练图像表示。有两种类型的值v_{i},标记y∈N *和图像表示z_{i} =e_{R} (x_{i})∈R^{d}

Effective and efficient retrieval(高效检索):如图1所示,检索数据库采用矩阵D∈R^{|D|\times d}作为样例的快速近似k-NN。为查询图像x_{q}生成查询向量z_{q} = e_{R} (x_{q})。我们利用查询向量z_{q}来检索它的近似k近邻,它们具有相应的表示z_{1}, z_{2},...,z_{k}除以矩阵D使用余弦相似度。在检索过程中,我们选择FAISS(相似向量检索库)[20]来高效地查询数据库。

Retrieval-enhanced visual prompt(检索增强的视觉提示)

该方法旨在利用检索数据库进行类比学习,以增强视觉提示学习。我们设计使用检索结果有条件地生成视觉提示,称为检索增强的视觉提示。检索增强视觉提示和图像patch标记的交叉注意响应图可视化如图2所示。RePrompt表现出更强的自注意反应,如有趣区域扩展和更高的注意值。

 (检索增强视觉提示与图像patch标记间注意反应映射的可视化。平均自注意图来自最后的视觉transformer层。)

检索模块从原始图像x_{q}中获取查询向量z_{q},并在矩阵D上执行查找,返回最相似的前k_{re}个候选者。将数据库中相应的表示z_{1}z_{2},...,z_{k_{re}}合并到图像编码器中,以增强视觉提示。我们基于k_{re}相邻表示的相似性直观地将其聚合为额外的融合向量z_{f}R^{d}[5],如下所示:

 查询向量z_{q},融合向量z_{f}和检索向量z_{1}, z_{2},…,z_{k_{re}}被连接起来形成输入\hat{i}= [z_{q}, z_{f}, z_{1},…,z_{k_{re}}]∈R^{d\times (k_{re+2})},生成检索增强的视觉提示f_{p}(\hat{i})R^{d\times (k_{re+2})}

如图3顶部所示,视觉提示学习者随机初始化J个视觉提示P_{I}^{1},…,P_{I}^{J}。然后利用J个检索增强卷积(REConv)块来处理J个输入\hat{i}_{1},...,\hat{i}_{J}(将\hat{i}复制J次)以生成动态提示。然后将这些提示与随机初始化的J个视觉提示组合,并插入到视觉分支的前J层的输入序列中:

 (视觉提示学习者和REConv概述。视觉提示学习器包括REConv,通过学习检索到的结果来生成动态视觉提示。)

12−J层处理剩余的输入可学习提示,并且学习这些提示退化为等式(2)。我们在补充材料中进一步讨论了插入深度的选择。

为了有效地融合检索到的表示,我们提出了一种检索增强卷积块(REConv)。如图3底部所示,REConv块由三个卷积层组成:两个1×1卷积,分别降低和缩放信道维数;一个3×3卷积,位于两个1 x 1卷积的中间。在这些卷积层之前,我们将视觉提示的1维token序列结构重塑为2维矩阵结构。REConv块并行处理\hat{i}以生成动态提示,其可以公式化为:

 其中LN是层归一化,β是缩放输出的超参数。关于处理检索到的表示的更多设计讨论可以在补充材料中找到。

Retrieval-based adapter(基于检索的适配器)

此外,我们使用了一个可微的kNN分类器作为基于检索的适配器。该适配器与检索增强的视觉提示一起进行训练,以为下游任务生成更自适应的预测概率。具体而言,给定查询实例x_{q},使用图像编码器提取查询向量 \hat{z}_{q},检索增强的视觉提示并在矩阵 D 上查找。检索过程返回|D|-具有相应内积相似性的最近邻居。然后,我们对每个标签yi在检索目标中所有出现的概率质量进行汇总。假设p_{kNN}表示查询实例x_{q}被预测的概率,则p_{kNN}(y|x_{q})通过p_{kNN}的加权和重新表述如下

 此外,通过用提示调优CLIP的预测p_{P}(y|x_{q})和插值p_{kNN}(y|x_{q})来重新表述p(y|x_{q}),以获得标签的最终概率:

 先前的半参数小样本分类工作Tip-Adapter-F[51]与我们提出的范式相似。主要区别在于,我们的查询特征是从带有检索增强提示符的CLIP图像编码器分支派生出来的。这使得查询图像特征更适合下游数据集,从而获得比Tip-Adapter更好的性能。

Retrieval-guiding training(检索-引导训练)

k近邻(kNN)主要关注查询实例[2]的近似邻域。利用kNN分类结果作为先验知识来指导RePrompt在训练过程中关注困难的例子很直观。hard samples通常是指置信度较低的非典型样本。为了计算局部概率分布,我们限制检索到的邻居集 K ⊆ D 中的样本数,其中k_{rc} ≠ |K|,如下:

 概率p_{kNN}对应于将查询实例x_{q}分类为特定类别的置信度。与Focal Loss[31]类似,p_{kNN}的负对数似然值被用作调整因子p_{t}=−log(p_{kNN})。调整因子通过调整由kNN区分的伪正确样本或伪误差样本的相对损失来重新加权交叉熵损失L_{CE}。最后的损失公式为:

 其中γ是比例因子。设|K| = C × n, n∈N^{*}。[5]在NLP的小样本学习任务上,采用了类似的损失来增强模型的性能。在小样本实验中,n可以为1、2、4、8、16以适应各种小样本设置。

实验

小样本学习

数据集:在11个图像分类数据集上评估了RePrompt。为了评估RePrompt在小样本学习设置下的性能,我们构建了1、2、4、8、16-shot的全类别训练集和整个测试集。

基线:我们将我们提出的RePrompt与四种现有的基于提示的方法进行了比较,即CoOP[54]、VPT-Deep[17]、Tip-Adapter-F[51]和VLPT。(1) CoOP学习与[CLASS]连接的上下文提示符作为文本编码器的输入。(2) VPT-Deep在视觉编码器的每一层transformer中插入一些可学习的视觉提示。(3)视觉语言提示调优(VLPT)在CoOP和VPT-Deep下共同优化不同模态编码器的提示。(4) Tip-AdapterF利用少量训练数据构建基于缓存模型的适配器。它可以被看作只是将kNN分类器的预测与CLIP分类结果内插。

训练细节:由于提示微调最初是在NLP中引入的一个概念,我们很自然地将这个概念扩展到统一的视觉和语言模型,通过在ViT模型上实现它,比如ViT- B /16,它由12个像文本编码器一样的transformer层组成。我们遵循CLIP的数据预处理协议,并在训练时冻结从预训练模型中继承的参数。我们设置检索模块的超参数如下:kre为 7,有9个随机随机初始化的提示符。n = 8,导致检索引导损失中检索C × 8个对象。配备检索增强提示的视觉transformer层为前7层。

主要实验是使用一组超参数进行训练和验证的。关于所有数据集的更多训练细节和检索结果可视化,请参阅附录A和C。

结果:基线方法和我们提出的RePrompt对小样本图像分类的性能如图4所示。我们的RePrompt在1/2/4/8/16-shot设置上的平均准确率始终优于之前的最佳基线+2.32/+1.57/+1.20/+0.89/+1.48(%)。基于VLPT, RePrompt实现了显著的性能提升,特别是在DTD上+5.87%,在Stanford Cars上+9.04%。RePrompt在具有丰富类别的挑战性数据集上大大超过了CoOP和VPT-Deep的性能,例如具有1000个类的ImageNet和具有397个类的SUN397。我们还观察到Reprompt在Oxford Pets和Food101上取得的改进较少。这可能是由这些数据集中的噪声数据引起的[54,4,50]。

 (在小样本设置下的11个数据集的主要结果。我们报告了1、2、4、8、16-shot的平均准确率(%)。提出的RePrompt在大多数下游识别数据集上实现了显着的性能改进。)

域泛化

预训练的视觉语言模型(如CLIP)具有较强的领域泛化能力。我们评估了提出的RePrompt在out- distribution (OOD)数据集上的鲁棒性。

数据集:我们遵循CoOp[54],使用五个数据集,即ImageNet、ImageNet V2[40]、ImageNetSketch[46]、ImageNet- A[14]和ImageNet- R[13],来评估RePrompt对out of distribution (OOD)数据的泛化能力。根据协议,我们在16-shot设置的ImageNet(源数据集)上训练模型,并在其他域移位数据集(目标数据集)上对其进行评估。因此,我们使用ImageNet 16shot实验的检索数据库作为目标数据集的检索数据库。

结果:表1总结了OOD实验结果,其中我们报告了源数据集和目标数据集的准确性。RePrompt在ImageNet V2和ImageNet- R上实现了最佳结果,并在Image-Sketch上展示了与UPT[50]相当的性能。这些结果表明Reprompt是一种合理且更健壮的提示调优方法。

 (领域泛化设置下的主要结果。我们报告了三次运行中16-shot的平均准确率(%)。)

消融实验

Component ablation:我们调查了RePrompt的有效性,并在表2中报告了结果。通过逐步引入检索增强模块(+Retrievalguiding training (+Rg training loss)),+检索增强视觉提示(+Re visual prompt),+基于检索的适配器(+Rb adapter)),RePrompt的平均准确率稳步提高。

 

检索参数消融:在ImgeNet上进行了一个实验,以验证检索的总体影响。表3显示kre = 7的平均性能稍好一些。检索增强提示符主要受REConv卷积核大小的影响。1-shot设置下的结果表明,检索增强提示可以学习类似类的共性。如表4所示,在16、8、4、2、1-shot设置中,n被正确设置为8,4,4,2,1。合理程度的KNN-引导训练在不同的小样本设置中是不同的。进一步的消融实验见附录B。

 

关于ReConv的讨论:我们讨论了用于融合令牌和查询的网络结构对ReConv模型最终结果的影响。为此,我们在ImgeNet上进行了额外的实验,其中我们替换了ReConv中的卷积层。ReRNN是用LSTM[15,41]层代替卷积层的一种变体。ReMLP是用多层感知器层代替卷积层的变体。此外,ReConv(zq)是一个仅将zq作为输入的变体。这些变体的实验设置与ReConv的实验设置一致,并且两个模型中的参数数量被控制为相似。

如表5所示,在所有shot 设置中,ReConv的表现都略好于其他工具。这些结果表明,用于融合令牌和查询的特定网络结构可能不会对整体模型的性能产生重大影响。网络结构的选择可能会影响模型的计算成本和优化难度,特别是当任务的复杂性增加时。

训练和推理时间:所有现有方法的性能比较如表6所示,包括在ImageNet上对16张照片分类的训练时间和推理速度。“Wo Retrieval”是一个与RePrompt具有相同数量可学习参数的可比模型。结果表明,与“Wo Retrieval”相比,RePrompt的性能提高了2.71%。此外,我们将训练时间减少到仅20个epoch,这与COOP相反。考虑到准确率提高和推理速度之间的权衡,我们认为对于小样本图像分类,推理速度是可以接受的。

 

总结

在本文中,我们提出了一个新的基于检索的框架RePrompt,它提高了提示学习方法在小样本分类任务中的性能。我们提出的方法由基于检索的适配器和检索增强的提示组成,以增加简单提示学习基线的不同阶段。大量的实验结果表明,该方法在小样本学习和域泛化方面均优于其他提示学习方法。我们希望我们的工作能够在以下值得关注的方向上激发进一步的研究:1)将提示学习扩展到其他下游任务,如分割2)探索检索以解决其他问题,如长尾分类。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值