Open-Vocabulary Multi-Label Classification via Multi-modal Knowledge Transfer 论文解读
前言
这篇文章应该是离我要做的东西最近的文章了,这是我需要去follow的一篇,来自AAAI2023。之前已经仔细看过一次,现在又回头看第二次了。我需要对这篇文章进行一个详细的汇报与分享,希望自己能做好!什么时候才能在这篇的基础上进行创新呢?拭目以待!!
Motivation
multi-label zero-shot learning(ML-ZSL)的方法可以更好地满足真实应用的需求,因为它能够预测从没见过的标签。文中对于目前的ML-ZSL的方法提出了存在的2个问题:
- 仅仅使用了单模态预训练模型(一般是语言预训练模型),忽略了图像-文本对的视觉语义信息。个人认为应该是文本图像不能够很好地、很正确地、很迅速的match在一起,因为先验知识掌握得不足。如果两个模块分开用两种不同的预训练模型,因为它们之间没有知识互通,所以会导致虽然两个模型都能很好地学习到图像/文本特征,但在跨模态的matching中还是会表现得不够好。
- 预训练语言模型只能够较好地处理word label,而不能拓展到text label(例如word label ‘cat’ & text label ‘black cat’,后者会带有一定的视觉属性。
上图中对比了一般的ML-ZSL框架和文章提出的MKT架构。在这个图中也可以看出来传统ML-ZSL的缺陷,主要是Image Encoder和Label Encoder完全是分开的,没有知识互通,在跨模态的比较中可能会遇到比较大的alignment问题。第二个点就是关于text label (‘Black Dog’),像Glove这样的预训练embedding模型只能是把它拆开,然后用两个word embedding叠在一起作平均。如果没有图像的帮助,最后在语义空间中或许不能作很好地匹配。文中MKT模型使用了VLP (Vision and Language Pre-trained) 模型,因为是用文本和图像对一起训练的,所以有知识的共享,更好的alignment和matching。
同时,open-vocabulary方法取得了很好的成果。open-vocabulary是在大量的文本-图像对上面去训练的,正好能解决单模态预训练模型的缺陷,因为它达成了图像与语言模型的知识互通。下面来区分一下open-vocabulary和zero-shot的细微差别:
(from Open-Vocabulary Object Detection Using Captions, CVPR2021)
- v Ω v_{\Omega} vΩ: entire language vocabulary
- v B v_{B} vB: classes for training (mainly for vision model)
- v C v_{C} vC: open set for training a large model (teacher model)
- v T v_{T} vT: classes for testing
对于open-vocabulary任务来说,一般需要使用很大的视觉-文本对,训练一个大模型;然后再使用这个大模型去完成其他下游的训练任务,或者用这个大模型来指导其他任务模型的训练。Open-Vocabulary其实是更加泛化的zero-shot learning,它和传统的ZSL区别在于,它能够识别任意的类别,而非同一领域相似但未见的类别。和Weakly Supervised区别在于,WS任务是需要已知训练的类别,它需要包含在这个open set中,而open-vocabulary的测试类别是不可知的,它可以是open set中任意的子集,也可以不是。
例如说,传统的ZSL任务就是,在cifar10的其中几类上训练,然后测试的时候能识别另外的几类。但open-vocabulary因为具有大量数据训练的模型,在它的指导和帮助下能够识别出与训练数据分布相差较大的类别,如cifar10训练,识别杯子、台灯这种十分不相关的类别。
Contribution
- open-vocabulary之前没有用于多标签分类任务,这篇文章首次探索这一应用,提出的MKT (multi-model knowledge transfer) 联合应用了图像-文本对的语义多模态信息。
- MKT中应用了知识迁移,引入了open-vocabulary的前置知识,用来保证图像与文本的matching。为了适应不同的下游任务,对text embedding使用了prompt tuning。使用two-stream module, 捕捉局部与全局的特征,实现更好的特征提取功能
- 在NUS-WIDE和Open Images上取得sota
Challenge
- The alignment between image and label embedding
- The relation between seen and unseen label embedding
- word-level label to text-level label
Method
文中将训练过程分成了2 stage。第一个阶段主要是训练ViT和Two-stream模块的参数,第二个阶段采用了prompt tuning,详情可见 Learning to Prompt for Vision-Language Models.
Stage 1
-
1 — 图片先划分为non-overlapping patch,外加一维class token指代分类,然后放入Vision Transformer,得到下面输出:
x L = [ o c l s , o p a t c h ] x_L=\left[o_{c l s}, o_{p a t c h}\right] xL=[ocls,opatch]
假设图像大小224×224×3,用一个16x16尺度,步长为16,通道数为768的卷积过一下图像,然后每一个patch大小为14×14,进入Transformer先拉平成196长度,然后外加一个class token变成197×768的大小。过完Transformer后还是保持一样的大小。 -
2 — 同一张图放入VLP Image Encoder,得到一个向量输出(大小是1×768),因为这里用的CLIP模型的Image Backbone也是ViT,但官方模型只输出一维图像embedding。在这里将VLP Image Encoder当作老师模型,用该输出和上面计算出的 O c l s O_{cls} Ocls进行距离计算(1范式),获得只是蒸馏部分Loss。
L dist ≜ ∥ Φ I C L I P ( x ) − o c l s ∥ 1 = ∥ o d i s t − o c l s ∥ 1 \mathcal{L}_{\text {dist }} \triangleq\left\|\Phi_I^{C L I P}(\mathrm{x})-\mathbf{o}_{c l s}\right\|_1=\left\|\mathbf{o}_{d i s t}-\mathbf{o}_{c l s}\right\|_1 Ldist ≜ ΦICLIP(x)−ocls 1=∥odist−ocls∥1 -
3 — 一个two stream Module,这也是和普通ViT模型较大的不同,就是利用好patch embedding,而普通ViT模型只是利用那个分类 embedding。这里让patch embedding过一个local线性层,因为patch主要捕捉的是local information;让class embedding过一个global线性层,将它们的维度都映射成public embedding space 需要的维度(512维)。
-
4 — 用VLP Text Encoder把prompt template转换为text embedding,维度默认为512维。
prompt template: 将下游任务转换为pre-training的任务。先制定一个句子模板,比如说文中是" There is a [mask] in the scene",mask部分填入的是标签词,后面需要做的是标签词映射,即给 定哪些标签词是positive的,哪些是negative的。引入模板和标签词本质可理解为一种数据增强, 通过增加提示的方式引入先验知识。
-
5 — 计算similarity score,对于class embedding, 在公共空间中直接和text embedding做cosine计算相似度;对于patch embedding,所有patch都要和text embedding做点积,但是只选出前K个得分较高的,并且加起来取平均。
s i = ⟨ z i , e c l s ⟩ + TopK ( [ ⟨ z i , e 1 ⟩ , ⟨ z i , e 2 ⟩ , … , ⟨ z i , e N ⟩ ] ) s_i=\left\langle\mathbf{z}_i, \mathbf{e}_{c l s}\right\rangle+\operatorname{TopK}\left(\left[\left\langle\mathbf{z}_i, \mathbf{e}_1\right\rangle,\left\langle\mathbf{z}_i, \mathbf{e}_2\right\rangle, \ldots,\left\langle\mathbf{z}_i, \mathbf{e}_N\right\rangle\right]\right) si=⟨zi,ecls⟩+TopK([⟨zi,e1⟩,⟨zi,e2⟩,…,⟨zi,eN⟩])
上面的 s i s_i si是该样本属于第i类标签的score, z i z_i zi是在embedding space中的text embedding, e c l s e_{cls} ecls和 e i e_i ei是经过two-stream出来的image embedding。 -
6 — 计算Ranking loss。
L rank ≜ ∑ i ∑ p ∈ y i , n ∉ y i max ( 1 + s i n − s i p , 0 ) \mathcal{L}_{\text {rank }} \triangleq \sum_i \sum_{p \in \mathbf{y}_i, n \notin \mathbf{y}_i} \max \left(1+s_i^n-s_i^p, 0\right) Lrank ≜i∑p∈yi,n∈/yi∑max(1+sin−sip,0) -
7 — 合并ranking loss 和 distillation loss
L stage 1 = L rank + λ L dist \mathcal{L}_{\text {stage } 1}=\mathcal{L}_{\text {rank }}+\lambda \mathcal{L}_{\text {dist }} Lstage 1=Lrank +λLdist
Stage 2
等到Stage 1中主要模型的参数都训练好之后固定住,去训练一个表现较好的prompt。这里的prompt不是采用人工制订的prompt template( like ‘This is a [label].’ ),而是把[label]前面的字都变成是可训练的。所以这个阶段我们主要训练一个prompt template类似 V 1 V 2 V 3 . . . V N [ l a b e l ] V_1V_2V_3...V_N[label] V1V2V3...VN[label] 这样的,可以是对于所有类别都是这个template, 这样能更好适应我们设定的任务。这种方法不需要人工制订,减少人力成本,也去除了我们需要大量领域知识的要求。
在这个阶段,我们的目标也是让图像和文本对尽可能正确匹配,但最后的Loss是有区别的,因为没有知识蒸馏操作。
L
stage
2
=
L
rank
\mathcal{L}_{\text {stage } 2}=\mathcal{L}_{\text {rank }}
Lstage 2=Lrank
Experiment
- 数据集:NUS-WIDE, Open-Images v4
- 任务:zero-shot learning( ZSL ), generalized zero-shot learning (GZSL)
- 对比模型:LESA, GAN-MLZSL, ZS-SDL, BiAM, CLIP-FT, MKT
对于NUS-WIDE和Open-Images采用不同的K。比较这个最新提出的BiAM模型,以及训练好的CLIP-FT模型都有提升。比BiAM在两个数据集上分别提升了2.5%和16.3%,同时比同样是open-vocabulary的CLIP-FT也有一部分提升。
Ablation Studies
- 探究了distill 和 prompt tuning 各自起到的作用和组合起到的作用,两个组合才能发挥更大作用(NUS-WIDE)
prompt tuning还可以更好地识别视觉信息,下面是做了tuning前后的对比:
- 采用不同的Label embedding方法(没有distill和prompt tuning)
在label retrival任务中,探索CLIP和它的prompt tuning版本能否更好地捕捉标签之间的相关性。图中的Girls和Airport应该是没有出现在可见标签中的,获取图像的Top-3标签,与真实标签做一下对比。CLIP+prompt tuning能够非常准确捕捉标签的相似信息。
下面这张图用t-SNE的方法体现了在2维上label embedding的相似性,也再次证明了CLIP能够更好地捕捉相似性。
- 探究two stream module的作用,local head在F1上表现好,global head在mAP上表现好。因为global head是一个更加全局的通用的、通用的表示。
- 探究超参数
λ
\lambda
λ 和local head中的
t
o
p
−
k
top-k
top−k
- 探究不同的backbone- VGG19 & ResNet50
- 标准的多标签分类任务对比
- 与BiAM的Attention map比较