CVPR 2024 | PromptKD:基于提示的视觉语言模型蒸馏

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

点击进入—>【扩散模型和多模态】交流群

添加微信:CVer444,小助手会拉你进群!

扫描下方二维码,加入CVer学术星球可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文搞科研,强烈推荐!

169dd8104a8c0ee2292757981125e977.jpeg

作者:Zheng Li(已授权转载)

https://zhuanlan.zhihu.com/p/684269963

d39343628c19c892bf88f101299aac46.jpeg

大家别只点收藏,多点点赞~

《PromptKD: Unsupervised Prompt Distillation for Vision-Language Models》

主页:https://zhengli97.github.io/PromptKD/

代码:https://github.com/zhengli97/PromptKD

论文:https://arxiv.org/abs/2403.02781

一句话概括:

PromptKD是一个简单有效的基于prompt的视觉语言模型蒸馏新方法,在prompt learning的11个benchmark数据集上大幅领先,达到了SOTA。

大白话背景介绍

已经很了解VLMs和prompt learning的同学可以直接跳过,到背景问题~

什么是视觉-语言模型(Vision-Language Models, VLMs)?

视觉语言模型VLM一般由两个部分构成,即视觉(Vision)部分和语言(Language)部分。

以一个经典的VLM网络 CLIP[1] 的结构为例:

496d90c3090042179538c6ab0520fca7.jpeg

图1. CLIP架构。图片来自于CLIP论文。

如图1所示,CLIP由text branch和image branch组成。

其中, text branch主要由transformer构成,当要进行cls_num个类的分类任务时,会取每个类别对应的名称,如"plane", "car", "dog",与"a photo of a"进行组合,作为prompt输入进text encoder,得到大小为[cls_num, feat_dim]的text feature。

image branch的核心就是对输入的图像提取image feature,其通常为ResNet或者ViT[2]。图像经过image encoder之后得到image feature,其大小为[batch_size, feat_dim]。

将两个feature进行相乘就得到了预测logits。

CLIP有两个明确的特性,是这个工作的基础:

  1. CLIP可以进行zero-shot分类,即对未见过的类别进行识别,并保持很高的性能。而传统的CNN或者ViT由于模型架构限制不可以。

  2. 对于已知的类别,CLIP的text branch只需要一次forward就可以得到对应text feature用于分类。

什么是提示学习(Prompt Learning)?

在Text Branch部分中,a photo of a {class_name} 这样的描述太过宽泛,明显不是最优的。例如对于图2(b)的花,手工设计的a flower photo a {class}要描述的更加精确,其产生的结果就更好。

c399e3b62665d5691d72443051ff3925.jpeg

图2. 蓝色方块代表手动设计的prompt,绿色方块代表网络学习得到的learnable prompt。绿色方块acc超越了蓝色。图片来自于CoOp论文。

这就产生来两个问题,第一,固定模板的prompt不是最优的。第二,针对性的手工设计费时费力,且无法泛化。

于是,提示学习(Prompt Learning)[3] [4]就提出将prompt变成了一种learnable的方式,通过优化的方法让prompt在下游数据集上学习适用的表征,来替代手工设计的prompt,参考图2中的绿色方块。

这样优势是,可以在少量数据的情况下,仅通过引入一少部分的可学习参数(即learnable prompt),就可以将原始的CLIP快速适用到下游的任务/数据,同时在性能上比全参数微调的结果更好[4]

实验衡量指标是什么?

有三个指标,分别是base acc,novel acc和harmonic mean。

以imagenet-1k数据集为例,会取1000类中的前500类作为base class,后500类作为novel class。模型在base class上训练,完成后在base class和novel class上测试acc性能。因为novel class与base class数据类别不重复,所以novel acc可以有效反应模型泛化性能。harmonic mean指标是对base acc和novel acc的综合反映,为harmonic mean = (2*base acc*novel acc) / (base acc+novel acc)。总体的harmonic mean值越高,模型综合性能越好。

背景问题

prompt learning的核心作用是,保持原始CLIP参数不变,通过引入小部分learnable prompt参数,来将大的原始的经过预训练的CLIP模型适用到下游任务/数据上,提升CLIP模型在下游任务的性能,同时保持CLIP模型zero-shot能力。

除去一直发展至今的各种设计prompt形式的工作[3] [5] [6] [7] [8] [9] [10] [11] [12] [13],现如今最前沿的prompt learning方法主要还可以分为另外两类:

1. 引入额外数据/信息。这一类工作核心就是通过引入额外的数据或信息,做法包括但不限于,

(1) 通过LLM来生成{class_name}相关的语句,获得额外的有关{class_name}的特性 特征[14] [15][16],或者更多描述性语句[17] [18] [19] [20]

(2) 引入额外的数据源,从wikipedia上引入文本描述[21],从额外数据集例如ImageNet-21K来做预训练 [22]

(3) 设计给原始图像数据引入额外的tag或标注[23] [24] [25]

从以上的方式我们看到,大部分引入额外数据信息的工作都是围绕text branch展开,本质原因是输入的text本身"{class_name}"或"a photo of a {classname}"包含信息太少,丰富度要远低于image,通过额外的域内文本信息的引入,可以显著增强text feature的质量。所以text feature的质量是关键。

同时,可以看到,围绕image branch的工作是相对较少的。这时候问题就来了:那我们可不可以用同样的思路来增强image feature呢?

诶,这个方法好!因为互联网内往往存在非常大量的图像数据,很容易获取。

但问题是这些图像往往是没有标注的,没办法用gt训,如果要去进行标注,需要消耗很多的时间或者钱。明显限制了这种方式的应用。

2. 利用原始CLIP自身信息约束模型学习[19] [26] [27] [28] [29] [30] [31],防止过拟合。

在Prompt learning中,learnable prompt的参数量是相对较少的,在经过大量base class数据训练之后,模型会对base class数据存在过拟合,丧失对novel class的泛化性能。要解决这个问题,一种非常有效的做法就是利用vanilla CLIP来约束带有prompt的模型的学习。

以ICCV 23 PromptSRC为例,如图3所示,

89862fd10aad1579d1464e029818f721.jpeg

图3. PromptSRC结构图。图片来自于PromptSRC论文。

图3这篇工作就看两条线,蓝线和灰线。

蓝线,就是原始CLIP的前向计算路径,分别会得到对应的image和text feature。

灰线,就是带有learnable prompt的计算过程,也会得到对应的feature。

在两条线的末尾,计算了三个loss,这里就是用原始CLIP产生的image和text feature来约束由含有learnable prompt产生的image和text feature。通过这样的约束,限制了prompt向着base class过拟合,达到了SOTA的性能。

由这个工作我们就想,如果换一个更好的模型来做约束是不是性能会更好?

于是,这就引出了我们的工作。

方法

PromptKD其实核心就在做一件事,引入更大的CLIP模型作为teacher,解决了上面提到的三个问题。

(1) 重用(Reuse) teacher CLIP产生的text feature用于学生的训练和推断。这样确保了text feature高质量的同时,还显著的节省计算量,训练时只涉及student的image encoder。

(2) 对齐学生CLIP和教师CLIP的logits。让大的CLIP模型给小的学生CLIP模型提供更好的监督。

(3) 因为有了教师CLIP的存在,就解决了数据量限制的问题,我们可以用大量的无标签domain data来训学生,不再拘泥于原来有限的有标签数据。在训练时,我们直接可以使用数据集的全量数据作为无标签数据进行蒸馏,这样一来就prompt就可以学到更广泛的domain knowledge。同时高性能的教师CLIP也保证了用于蒸馏的软标签的准确性。

我们先来看一个简单的结构缩略图:

dce20debca1764c2fd2876465f759330.jpeg

图4. PromptKD框架简略图。

黄色的方块部分代表的就是教师CLIP,在教师CLIP经过训练之后,直接一次forward,得到并保存下来对应类别的text feaure,也就得到了图4中的Pre-stored Text Feature。

蓝色的方块代表的是学生CLIP,这里其实就只有一个image encoder,在带有learnablr prompt的输入进入image encoder之后会得到对应的image feature,这是因为与teacher text feature在维度上不匹配,所以经过一个Projector,将512转成768维的特征。然后再与Pre-stored Text Feature相乘,得到logits。

然后进行蒸馏。

完整的框架图如图5所示:

a3a79e4e80c20082dfe3c2c9d4e4282d.jpeg

图5. PromptKD整体框架图。

图5里就是图4过程的细化。

这里将PromptKD的每个阶段都进行了详细的阐明。大家看图就明白了~

第一阶段,教师模型的预训练。在这里,我们选择之前的SOTA方法PromptSRC去预训练我们的教师ViT-L/14 CLIP模型,我们的学生模型是ViT-B/16 CLIP模型。

注意,这里的预训练不是必须的一步,选择去预训练教师模型,是为了让教师有一个更好的性能,从而有更好的学生蒸馏结果。如果直接使用vanilla ViT-L/14 CLIP作为教师,相比于baseline,也取得了明显的性能提升,具体结果请参考表4。

第二阶段,学生CLIP模型的蒸馏。

第三阶段,学生的推断。

最后再来一个简洁明了的流程概括图:

151e916d118abd9359a31c1cfddd4b30.jpeg

图6. 计算流程

实验结果

我们的PromptKD方法在prompt learning的11个benchmark dataset上都达到了SOTA的性能。

Base-to-novel实验:

a4a58782be45e5b8b61bdeeae9aec96c.jpeg

表1. Base-to-novel实验结果。

aada98b7874dc662e5ac2ab60ec17dfb.jpeg

图7. HM分数在11个数据集上的总揽图。

Cross-dataset实验

86afddedc6a29ed65234eeabacb60cd9.jpeg

表2. Cross-dataset实验结果。

消融实验

为了实验快速进行,消融实验里使用的不是全量数据集,而是64 shots per class进行的训练。所以会与表1中的数据相比略低。

与其他同样使用了无标签数据的工作的性能对比

8044d6c871c1a676bfe0945b2428d842.jpeg

表3. 在Flowers102数据集上与使用了无标签数据的其他方法的对比结果。

教师预训练方法的选择

在PromptKD中,任意类型的ViT-L/14 CLIP教师模型都可以蒸馏出一个很好的ViT-B/16 CLIP模型,相比于baseline (70.22 HM)都有明显的提升。

这里有一点非常有意思的是,我们可以看到,第四行的Teacher(CLIP) ViT-L/14也就是原始的CLIP模型,在经过PromptKD的蒸馏之后,我们的ViT-B/16 CLIP的结果(表1(b))明显超过了原始的ViT-L/14 CLIP模型。(77.62 vs. 76.52)

86718229c88c9dd6a17c36f8dae30cd8.jpeg

表4. 不同教师预训练方法对PromptKD蒸馏效果的影响。

不同容量教师模型的选择

如表5所示,绿色代表学生ViT-B/16 CLIP的HM分数,土黄色代表教师的HM分数。教师的性能越高,越能训练出更好的学生。

878d8e70f2f288d61af2a7ddfda08217.jpeg

图8. 不同容量的CLIP模型作为教师进行蒸馏。

欢迎大家试用PrompKD~

Acknowledgement

这篇论文解读感谢师弟武戈同学的部分论文总结,PromptKD这篇工作也非常感谢蚂蚁的申书恒,张长浩和傅幸同学的讨论和帮助。

何恺明MIT授课的课件PPT下载

在CVer公众号后台回复:何恺明,即可下载本课程的152页课件PPT!赶紧学起来!

CVPR 2024 论文和代码下载

在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集

多模态和扩散模型交流群成立

 
 
扫描下方二维码,或者添加微信:CVer444,即可添加CVer小助手微信,便可申请加入CVer-多模态和扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF等。
一定要备注:研究方向+地点+学校/公司+昵称(如多模态或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

 
 
▲扫码或加微信号: CVer444,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集近万人!

▲扫码加入星球学习
 
 
▲点击上方卡片,关注CVer公众号
整理不易,请点赞和在看
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值