【提示学习论文】KDPL:Improving Zero-shot Generalization of Learned Prompts via Unsupervised Knowledge Distil

Improving Zero-shot Generalization of Learned Prompts via Unsupervised Knowledge Distillation(ECCV 2024)

  • 利用无监督知识蒸馏改进学习提示的zero-shot泛化
  • image+text
  • 佛罗伦萨大学、比萨大学
  • 代码:https://github.com/miccunifi/KDPL

1 KDPL

![[KDPLf2.png]]

教师模型

  • 冻结的text encoder:输入 a photo of a class,得到文本特征ψi,T,B
  • 冻结的image encoder:输入图像,得到图像特征ψI,B
  • 计算教师概率pT
    在这里插入图片描述

学生模型

  • text encoder:输入class+learnable prompts,得到ψi,T,S​
  • image encoder:输入图像和prompts,得到ψI,S
  • 计算学生概率pS
    在这里插入图片描述

知识蒸馏

将教师模型概率pT与学生模型的概率pS进行对比,通过对称KL散度损失函数进行知识蒸馏,更新学生模型的提示γ

2 标签不可知的提示学习

1 KDPL overview

可以在没有类别名称或标签信息的情况下,与任意现有的提示学习方法无缝集成。

  • 标签不可知:不使用真实标签,但假设知道训练数据集中的类别名称
  • 类别不可知:更进一步,假设训练类别名称也是不可知的。此时从包含大约20k个类别名称的大词典中自动筛选类别。

2 KDPL

  • 训练过程中不使用图像的真实标签,但我们知道训练数据集中存在的类别名称。
  • 例如我们知道fish、cat、cow
  • 使用教师模型进行zero-shot分类,计算每个类别的概率pT
  • 学生模型计算已知类别的概率分布pT
  • 通过对齐KL散度损失函数优化学生模型提示

怎么对齐的?损失函数计算的部分,比如fish的与哪个对齐??
老师和学生模型都输入了训练集类别名称,一样的class进行对齐就好。

使用KL散度,计算老师预测概率与学生预测概率之间的损失:
![[KDPLg3.png]]在这里插入图片描述

  • 蒸馏损失的计算,取决于老师的固定预测、学生的即时预测和类别集合C。
  • KDPL可以用于标签不可知类别不可知的适应场景。我们实验发现,对称KL散度略优于任何一种非对称选项。

3 类别不可知的提示学习

  • 我们不知道训练数据集中所有类别的名称。

  • 从一个大型词汇表Open Images V7 dataset(包含20k类别)中自动选择每个批次最相关的类别。

  • 给定一个图像批次X=Ii N,和所有类别名称C,使用教师模型对所有图像和词汇表中的所有类别(20k)进行推理,得到概率pT

  • 生成概率矩阵PT:对于图像批次的每张图,计算每个类别的概率,沿着批次维度堆叠,得到概率矩阵:![[KDPLg1.4.png]]

  • 沿批次轴计算平均概率,得到![[KDPLg1.5.png]],表示每个类别在整个批次中的平均概率

  • 根据平均概率,选择K个最高类别,作为学生模型的输入

4 实验

实现细节:

  • CLIP ViT-H-14作为老师模型
  • ViT-B/32作为学生模型
  • K=1000

域泛化:
![[KDPLt1.png]]

  • 17
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值