标题:《PromptKD: Unsupervised Prompt Distillation for Vision-Language Models》
主页:zhengli97.github.io/PromptKD
论文:https://arxiv.org/pdf/2403.02781.pdf
代码:https://github.com/zhengli97/PromptKD (已开源,欢迎star~)
原文:https://zhuanlan.zhihu.com/p/684269963@Zheng Li
TL;DR
PromptKD
是一个简单有效的 prompt-based 的 VLM 蒸馏新方法,在 prompt learning 主流的 11 个 benchmark 数据集上大幅领先,达到了 SOTA。
前情回顾
What’s the VLM?
Vision-Language Models, VLMs, 即视觉语言模型,顾名思义就是由视觉(Vision)部分和语言(Language)部分所组成。说到 VLMs,最为经典和被人熟知的当属 OpenAI 开源的 CLIP 模型:
如上图所示,CLIP 由两个并行分支组成,即图像分支和文本分支。
其中, 文本分支主要由 transformer 块构成,当要进行 cls_num 个类的分类任务时,会取每个类别对应的名称,如 “plane”, “car”, “dog”,与模板 a photo of a
进行组合,作为提示喂给文本编码器,得到大小为 [cls_num, feat_dim] 的文本特征。
而图像分支的核心就是对输入的图像提取图像特征,其通常为 ResNet 或者 ViT 等骨干网络。图像经过图像编码器之后得到图像特征,其大小为 [batch_size, feat_dim]。
最后,我们只需要将上述两个编码后的特征进行相乘就得到了预测的 logits 分布。
需要注意的是,此处 CLIP 有两个明确的特性,是本工作的基础:
- CLIP 可以进行 zero-shot 分类,即对未见过的类别进行识别,并保持很高的性能,这在传统的 CNN-based 或者纯 ViT-based 的模型是很难做到的。
- 对于已知的类别,CLIP 的文本分支仅需一次 forward 就可以得到对应的文本特征参与到后续的分类任务中。
What’s the Prompt Learning?
如上展示,CLIP 在文本分支中直接采用模板 a photo of a {class_name} 作为提示,然而这样的描述过于宽泛,显然不是最优的。例如,对于下图2(b)的花,手工设计的 a flower photo a {class}
要描述的更加精确,其产生的结果自然就会更好:
其中,蓝色方块代表手动设计的 prompt,绿色方块代表网络学习得到的 learnable prompt。显然,绿色方块的精度超越了蓝色方块给定的 prompt。问题来了:
- 第一,固定模板的 prompt 不是最优的。
- 第二,针对性的手工设计不仅费时费力,且无法泛化。
于是,提示学习(Prompt Learning)应时而生,其思想是提出将 prompt 变成一种 learnable 即可学习的方式,通过优化的方法让 prompt 在下游数据集上学习适用的表征,来替代手工设计的 prompt,例如 a photo of [CLASS]
这种单一的模板。
这样做的优势是,可以在少量数据的情况下,仅通过引入一少部分的可学习参数(即 learnable prompt),就可以将原始的 CLIP 快速适用到下游的任务或数据上,同时在性能上比全参数微调的结果会更好。
在正式介绍之前,有必要先了解下实验指标,此处主要涉及三个:base acc,novel acc 和 harmonic mean。
以 imagenet-1k 数据集为例,会取 1000 类中的前 500 类作为 base class,后 500 类作为 novel class。模型在base class上训练,完成后在base class和novel class上共同测试 acc 性能。由于 novel class 与 base class 数据类别不重复,所以 novel acc 可以有效反应模型泛化性能。
harmonic mean 指标则是对 base acc 和 novel acc 的综合反映,与 F1-score 类似,其数学表达式为:
h a r m o n i c m e a n = ( 2 × b a s e a c c × n o v e l a c c ) ( b a s e a c c + n o v e l a c c ) harmonic mean = \frac{(2 \times base_acc \times novel_acc)}{(base_acc + novel_acc)} harmonicmean=(baseacc+novelacc)(2×baseacc×novelacc)
总体的 harmonic mean 值越高,模型综合性能越好。
Motivation
OK,通过上面的快速入门,我们了解到了 prompt learning 的核心作用是:在保持原始 CLIP 参数不变的前提下,通过引入小部分 learnable prompt 参数,来将大的原始的经过预训练的 CLIP 模型快速适用到下游任务或数据上,从而提升 CLIP 模型在下游任务的性能,同时保持 CLIP 模型 zero-shot 能力。
除去一直发展至今的各种设计 prompt 形式的工作,现如今最前沿的 prompt learning 方法主要还可以分为另外两类:
引入额外数据/信息
这一类工作核心就是通过引入额外的数据或信息,做法包括但不限于,
- 通过 LLM 来生成 {class_name} 相关的语句,获得额外的有关 {class_name} 的特征或者更多的描 caption;
- 引入额外的数据源,例如可以从 wikipedia 上引入文本描述,或干脆从额外数据集例如 ImageNet-21K 来做预训练;
- 设计给原始图像数据引入额外的 tag 或标注。
从以上的方式我们看到,大部分引入额外数据信息的工作都是围绕文本特征进行展开,本质原因是输入的文本本身 “{class_name}” 或 “a photo of a {classname}” 包含信息太少了!其丰富度要远低于图像,通过额外的域内文本信息的引入,可以显著增强文本特征的质量。【所以 text feature 的质量是关键】
同时,可以看到,围绕图像分支改进的工作是相对较少的,于是很自然的想法,我们是不是也可不可以用同样的思路来增强 image feature 呢?
诶,这个方法好!因为互联网内往往存在非常大量的图像数据,很容易获取。但问题是这些图像往往是没有标注的,没有标签可以进行监督学习;另外,如果要人工进行标注的话,往往需要消耗不少人力物力。这明显限制了这种方式的应用。
利用原始CLIP自身的信息
在 Prompt learning 中,learnable prompt 的参数量是相对较少的,在经过大量 base class 数据训练之后,模型会对 base class 数据存在过拟合,丧失对 novel class 的泛化性能。要解决这个问题,一种非常有效的做法就是利用 vanilla CLIP 来约束带有 prompt 的模型的学习。
以ICCV’23 上的 PromptSRC 工作为例,如图3所示:
其中:
蓝线部分,就是原始 CLIP 的前向计算路径,分别会得到对应的 image 和 text feature。
灰线部分,就是带有 learnable prompt 的计算过程,也会得到对应的 feature。
在两条线的末尾,计算了三个 loss,这里就是用原始 CLIP 产生的 image 和 text feature 来约束由含有 learnable prompt 产生的 image 和 text feature。通过这样的约束,限制了 prompt 向着 base class 过拟合,达到 了SOTA的性能。
鉴于这个工作我们就想,如果换一个更好的模型来做约束性能会不会表现得更好?
方法
通过上面的动机分析,我们不难看出,PromptKD 的目标就是想引入更大的 CLIP 模型作为 teacher,以此来解决上面提到的三个问题。
-
重用 teacher CLIP 产生的 text feature 用于学生的训练和推理。这样一来可以keep 住 text feature 的质量,二来还可显著的节省计算量,因为训练时只涉及 student 的 image encoder,不会用到 teacher 的图像分支。
-
对齐学生 CLIP 和教师 CLIP 的 logits,即让大的 CLIP 模型给小的学生 CLIP 模型提供更好的监督。
-
因为有了教师 CLIP 的存在,就解决了数据量限制的问题,我们可以用大量的无标签 domain data 来训学生,不再拘泥于原来有限的有标签数据。在训练时,我们直接可以使用数据集的全量数据作为无标签数据进行蒸馏,这样一来就 prompt 就可以学到更广泛的 domain knowledge。同时高性能的教师 CLIP 也保证了用于蒸馏的软标签的准确性。
先看下上图(a),黄色部分代表的就是教师 CLIP,在教师 CLIP 经过训练之后,直接一次 forward,得到并保存下来对应类别的 text feaure,也就得到了图中的 Pre-stored Text Feature。
蓝色部分代表的是学生 CLIP,这里其实就只有一个 image encoder,在带有 learnablr prompt 的输入进入 image encoder 之后会得到对应的image feature,这是因为与teacher text feature在维度上不匹配,所以经过一个Projector,将512转成768维的特征。然后再与 Pre-stored Text Feature 相乘,得到logits。
最后,再进行蒸馏即可。具体流程可以总结如下:
- 第一阶段,教师模型的预训练。在这里,我们选择之前的SOTA方法PromptSRC去预训练我们的教师ViT-L/14 CLIP模型,我们的学生模型是ViT-B/16 CLIP模型。
- 第二阶段,学生CLIP模型的蒸馏。
- 第三阶段,学生的推断。
注意,这里的预训练不是必须的一步,选择去预训练教师模型,是为了让教师有一个更好的性能,从而有更好的学生蒸馏结果。如果直接使用vanilla ViT-L/14 CLIP作为教师,相比于baseline,也取得了明显的性能提升,具体结果请参考表4。
实验
本文提出的 PromptKD 方法在 prompt learning 的 11 个 benchmark dataset 上均达到了 SOTA 的性能。
Base-to-novel
Cross-dataset
Ablation
注:为了实验快速进行,消融实验里使用的不是全量数据集,而是 64 shots per class 进行的训练。所以会与表1中的数据相比略低。
在PromptKD中,任意类型的ViT-L/14 CLIP教师模型都可以蒸馏出一个很好的ViT-B/16 CLIP模型,相比于baseline (70.22 HM)都有明显的提升。
这里有一点非常有意思的是,我们可以看到,第四行的Teacher(CLIP) ViT-L/14也就是原始的CLIP模型,在经过PromptKD的蒸馏之后,我们的ViT-B/16 CLIP的结果(表1(b))明显超过了原始的ViT-L/14 CLIP模型。(77.62 vs. 76.52)
最后,再探讨下不同容量教师模型的选择。如表5所示,绿色代表学生ViT-B/16 CLIP的HM分数,土黄色代表教师的HM分数。可以看出:教师的性能越高,越能训练出更好的学生。
总结
本文介绍了一个用于视觉-语言模型的两阶段无监督提示蒸馏框架。该框架旨在通过使用未标记的领域数据,将大型CLIP教师模型的知识转移给轻量级CLIP学生模型,通过提示模仿。首先在领域少样本标记数据上对大型教师模型进行预训练,然后在大量未标记的领域数据上执行学生提示蒸馏。通过利用CLIP独特的解耦模态特性,我们提出重用预存的教师文本特征,并将其合并到学生图像编码器中,用于蒸馏和推理。
通过对11个识别数据集进行的大量实验表明了我们方法的有效性。但是,蒸馏方法的有效性与通过未标记领域样本传递的知识密切相关。当蒸馏数据缺乏来自目标领域的代表性时,蒸馏后的学生模型对该特定领域的泛化能力可能会出现偏差或削弱。未来,我们计划探索潜在的正则化方法以减轻这些问题。
最后,欢迎对多模态学习技术感兴趣的同学添加小编微信:cv_huber,备注"交流学习",一同加入群聊讨论学习。