[ICLR 2023] LPT: Long-tailed Prompt Tuning for Image Classification

Introduction

  • 作者提出 Long-tailed Prompt Tuning (LPT),通过 prompt learning 来解决长尾问题,包括 (1) 使用 shared prompt 学习 general features 并将预训练模型 adapt 到 target domain;(2) 使用 group-specific prompts 学习 group-specific features 来提高模型的 fine-grained discriminative ability

在这里插入图片描述

Preliminary Study

Performance Investigation of VPT (Visual Prompt Tuning)

  • 作者首先通过对比 VPT (Visual Prompt Tuning) 和 linear probing 在 Places-LT 数据集上的精度来说明 prompt tuning 对长尾数据集是有效的 (VPT 的输入为 input tokens 加上 learnable prompts (tokens),同时和 linear probing 一样在预训练模型最后加上 linear classifier)
  • 从下表中可以看出:a) prompt tuning 可以持续提高模型的 LTR 性能;b) prompt tuning 对长尾分布具有鲁棒性,能更好地学习尾部类别。同时也可以注意到,简单的 prompt tuning 并不能直接让模型在长尾数据集上达到 SOTA
    在这里插入图片描述

Analysis of Prompt Tuning

  • 作者接下来分析了为什么 prompt tuning 适合长尾识别 (但仍然没有从原理上分析为什么)
  • 由下图的 LDA 可视化可以看出 (use the pretrained ViT-B and the ViT-B fine-tuned by VPT on Places-LT to extract features of ImageNet val set and Places-LT val set),prompt tuning 可以很好地将下游任务数据分布 (Places-LT) 和预训练数据分布 (ImageNet) 对齐,可以更好地让预训练模型 adapt 到长尾任务的 target domain (from domain adaptation perspective)
    在这里插入图片描述
  • 作者计算了 ViT-B 和 VPT 输出特征的平均类内距离、平均类间距离以及两者之商 γ \gamma γ,可以看到,VPT 的平均类内距离和 γ \gamma γ 都更小,KNN 分类准确率更高,说明 VPT 输出的特征更具有区分度
    在这里插入图片描述

Long-tailed Prompt Tuning (LPT)

在这里插入图片描述

Phase 1: Shared Prompt Tuning

在这里插入图片描述

  • 类似于 VPT-Deep,给 ViT 的 L L L 层都各自加上额外的 prompts,因此 phase 1 需要优化 shared prompt u = [ u 1 , . . . , u L ] \mathbf u=[\mathbf u_1,...,\mathbf u_L] u=[u1,...,uL] 和 cosine classifier f f f,其中 shared prompt 用于学习所有类别的共同特征,并带来了上节讨论的 prompt tuning 的各种好处,包括 domain adaptation 和输出更具区分度的特征
  • 每层里的前向过程
    在这里插入图片描述在这里插入图片描述在这里插入图片描述其中, c \mathbf c c 为 [CLS], z \mathbf z z 为 token embed. 新添加的 prompts 不需要计算对应的自注意力输出,只需要作为 key 和 value 与 token embed 做交互即可
  • 损失函数
    在这里插入图片描述在这里插入图片描述

Phase 2: Group Prompts Tuning

在这里插入图片描述

  • 作者在 phase 2 加入了 m m mgroup-specific prompts R = { ( k 1 , r 1 ) , . . . , ( k m , r m ) } \mathcal R=\{(\mathbf k_1,\mathbf r^1),...,(\mathbf k_m,\mathbf r^m)\} R={(k1,r1),...,(km,rm)} 用于学习 group-specific knowledge 从而增强模型的 fine-grained discriminative ability,其中 k i \mathbf k_i ki i i i-th group 的 key, r i \mathbf r^i ri i i i-th group 的 prompts,包含 L − K L-K LK 个 prompt 序列 (只在后 L − K L-K LK 层使用 group-specific prompts).
  • Phase 2 包含两个步骤:(1) 冻住 shared prompts,经过 L L L 层推理得到 c L \mathbf c_L cL 作为 query q \mathbf q q m m m 个 keys 计算余弦相似度,选出相似度最高的 k k k 个 groups
    在这里插入图片描述然后对选出的 k k k 个 groups 的 prompts 进行 prompt ensembling
    在这里插入图片描述(2) 重新使用步骤 (1) 在前向传播中得到的 ( c K , z K ) (\mathbf c_K,\mathbf z_K) (cK,zK),在后 L − K L-K LK 层重新进行前向传播,每层的输入包括 [CLS] embed c \mathbf c c、patch embed z \mathbf z z、shared prompt u \mathbf u u 和 group-specific prompt r \mathbf r r,每层里的前向过程为
    在这里插入图片描述在这里插入图片描述在这里插入图片描述
  • 损失函数
    在这里插入图片描述在这里插入图片描述其中, β \beta β 为 scale factor,第二项损失函数被用于增大 q \mathbf q q 和其匹配的 k k k 个 groups 的 keys 之间的余弦相似度,这是由于 Phase 1 生成的特征已经比较 compact 并且在 Phase 2 是不变的,因此该损失项可以使得 keys 靠近特征空间中的不同聚类中心,使得不同 groups 对应不同的 group-specific feature
  • Dual Sampling. class-balanced sampling 和 instance-balanced sampling 分别容易使得模型对尾部和头部类别过拟合,作者采用 Dual Sampling,从 instance-balanced sampler 和 class-balanced sampler 分别采样一个 mini-batch { I } ins \{\mathbf I\}_{\text{ins}} {I}ins { I } bal \{\mathbf I\}_{\text{bal}} {I}bal. { I } bal \{\mathbf I\}_{\text{bal}} {I}bal 的损失函数对应 β = 1 \beta=1 β=1 时的 L P 2 \mathcal L_{\mathbf P_2} LP2 { I } ins \{\mathbf I\}_{\text{ins}} {I}ins 的损失函数对应 β = η ( E − e ) / E \beta=\eta(E-e)/E β=η(Ee)/E 时的 L P 2 \mathcal L_{\mathbf P_2} LP2,其中 η = 0.5 \eta=0.5 η=0.5 为 initialized weight, E E E 为总的训练 epoch 数, e e e 为当前 epoch 数

Loss Function

  • phase 1/2 中使用的 L cls \mathcal L_{\text{cls}} Lcls 采用 asymmetric GCL loss L A-GCL \mathcal L_{\text{A-GCL}} LA-GCL.
  • 首先根据 GCL 对 logits s ^ \hat {\mathbf s} s^ 进行加上 bias 和 rescale
    在这里插入图片描述其中, α \alpha α 为 scaling factor, ϵ \epsilon ϵ 为从高斯分布中采样的随机变量 ( ∥ ϵ ∥ \|\epsilon\| ϵ 为取绝对值), n i n_i ni 为训练集中类别 i i i 的样本数, n m a x n_{max} nmax 为训练集中的最大类别样本数. 对应的 per-class probability 为
    在这里插入图片描述
  • 然后根据 ASL 进行 Asymmetric Focusing
    L A − G C L = − y j ( 1 − p j ) λ + log ⁡ ( p j ) − ∑ 1 ≤ i ≤ C , i ≠ j y i ( p i ) λ − log ⁡ ( p i ) \mathcal{L}_{\mathrm{A}-\mathrm{GCL}}=-\mathbf y_{\mathrm j}\left(1-\mathbf{p}_{\mathrm{j}}\right)^{\lambda_{+}} \log \left(\mathbf{p}_{\mathrm{j}}\right)-\sum_{1 \leq \mathrm{i} \leq \mathrm{C}, \mathrm{i} \neq \mathrm{j}}\mathbf y_{\mathrm i}\left(\mathbf{p}_{\mathrm{i}}\right)^{\lambda_{-}} \log \left(\mathbf{p}_{\mathrm{i}}\right) LAGCL=yj(1pj)λ+log(pj)1iC,i=jyi(pi)λlog(pi)其中, j j j 为输入样本的标签类别, λ + = 0 , λ − = 4 λ_+=0,λ_−=4 λ+=0,λ=4 为 focusing parameter, y \mathbf y y 为 label smoothing 后的类别标签向量,即 y j = 0.9 + 0.1 / C , y i = 0.1 / C \mathbf y_{\mathrm j}=0.9+0.1/C,\mathbf y_{\mathrm i}=0.1/C yj=0.9+0.1/C,yi=0.1/C (疑问:ASL 本来是 BCE 上用的,但这里是 CE + label smoothing 之后再加上 ASL 的动态加权, ( 1 − p j ) λ + \left(1-\mathbf{p}_{\mathrm{j}}\right)^{\lambda_{+}} (1pj)λ+ 的意义和 ASL 一样,都是筛选出难样本,但感觉 ( p i ) λ − \left(\mathbf{p}_{\mathrm{i}}\right)^{\lambda_{-}} (pi)λ 的意义已经和 ASL 完全不同了,可以等进一步理解 label smoothing 为什么有用之后再来看)

Experiments

  • Model. ViT-B/16 with ImageNet-21k pretrained model.
  • Shared Prompt. default length of prompt as 10.
  • Group-specific Prompts. shared layer number K = 6 K = 6 K=6 and the size of prompt size m = 20 m = 20 m=20; for each prompt in the set, the prompt length is also set as 10 (Note that setting K = 6 K = 6 K=6 may lead to 1.5x inference cost compared to VPT). prompt ensemble number k = 2 k = 2 k=2.

Comparison with State-of-The-Art Methods

  • Comparison on Places-LT.
    在这里插入图片描述在这里插入图片描述
  • Comparison on CIFAR100-LT.
    在这里插入图片描述
  • Comparison on iNaturalist 2018.
    在这里插入图片描述

Robustness with Domain Shift

在这里插入图片描述

Ablation Study

  • Different Model Size and Pretrained Models.
    在这里插入图片描述在这里插入图片描述

  • Effect of Each Phase.
    在这里插入图片描述

  • Decoupled Training. during joint training, the shared prompt is still updated simultaneously, thus the query function is sub-optimal during training, resulting in worse matching results.
    在这里插入图片描述

  • Query Function and Group Size m m m.
    在这里插入图片描述when we further increase the size to 40, the final accuracy declines to 49.87%. A possible reason is that, some classes in the dataset may share some similar group-specific feature or knowledge在这里插入图片描述

  • Effect of K K K. K K K 过大会导致无法学得有效的 group-specific knowledge,过小会导致 Phase 2 匹配 groups 时无法充分利用 Phase 1 得到的 adapted feature representation
    在这里插入图片描述

  • Effect of Ensemble Number k k k.
    在这里插入图片描述

  • Effect of Asymmetric GCL Loss.
    在这里插入图片描述

  • Statistic of Prompt Matching.
    在这里插入图片描述

References

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Crossformer是一种利用交叉维度依赖性进行多元时间序列预测的Transformer模型。这个模型的动机是填补之前Transformer在处理多元时间序列时对不同变量之间关系刻画不足的问题。之前的Transformer更多地关注如何通过时间维度的注意力机制建立时序上的关系,而忽略了变量之间的关系。Crossformer通过引入时间维度和变量维度两个阶段的注意力机制来解决这个问题。特别是在变量维度上,Crossformer提出了一种高效的路由注意力机制。这篇论文填补了多元时间序列预测中变量关系建模的空白,并在ICLR2023中被提出。\[1\]\[2\]\[3\] #### 引用[.reference_title] - *1* *3* [【ICLR 2023】 CrossFormer增强多元时间序列建模能力](https://blog.csdn.net/qq_33431368/article/details/129483613)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [读论文《Crossformer:利用跨维度依赖进行多变量时间序列预测的Transform》](https://blog.csdn.net/vzvzvzv/article/details/131376526)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值