CoOp: Learning to Prompt for Vision-Language Models
论文地址:https://arxiv.org/pdf/2109.01134.pdf
CoOp的全称为Context Optimization,即上下文优化,其将CLIP中人工设置的Prompt,变为一个可学习的(learnable)prompt,并经过在11个下游任务上验证发现,CoOp+CLIP极大的提升了原CLIP的性能。
1.CoOp模型
有两种方法,分别为unified context和class-specific context,其中unified context指同一个数据集训练一个固定的context,而class-specific context是针对同一个数据集中的不同类别训练不同的context。并且作者把预测标签class放的位置有中间位置和末尾位置两种。
对于Unified Context,输入text encoder的prompt可表示为:
t
=
[
V
]
1
[
V
]
2
.
.
.
[
V
]
M
[
C
L
A
S
S
]
或
t
=
[
V
]
1
.
.
.
[
V
]
M
2
[
C
L
A
S
S
]
[
V
]
M
+
1
2
.
.
.
[
V
]
M
t=[V]_{1}[V]_{2}...[V]_{M}[CLASS]\quad 或\quad t=[V]_{1}...[V]_{\frac{M}{2}}[CLASS][V]_{\frac{M+1}{2}}...[V]_{M}
t=[V]1[V]2...[V]M[CLASS]或t=[V]1...[V]2M[CLASS][V]2M+1...[V]M
对于一个数据集,需要训练得到M个context token。其中[CLASS]在预测时更换成各个类的名称。在执行预测时,对类别i的预测概率为:
p
(
y
=
i
∣
x
)
=
e
x
p
(
<
g
(
t
i
)
,
f
>
)
/
τ
)
∑
j
=
1
K
e
x
p
(
<
g
(
t
j
)
,
f
>
)
/
τ
)
p(y=i|x)=\frac{exp(<g(t_{i}),f>)/\tau )}{\sum_{j=1}^{K}exp(<g(t_{j}),f>)/\tau )}
p(y=i∣x)=∑j=1Kexp(<g(tj),f>)/τ)exp(<g(ti),f>)/τ)
其中f为image feature,ti表示把[CLASS]换成第i类的名称。最终得到的结果就是计算当前image feature属于第i类的概率。
对于Class-Specific Context(CSC),输入的text encoder的prompt可表示为:
t
=
[
V
]
1
i
[
V
]
2
i
.
.
.
[
V
]
M
i
[
C
L
A
S
S
]
≠
t
=
[
V
]
1
j
[
V
]
2
j
.
.
.
[
V
]
M
j
[
C
L
A
S
S
]
或
t
=
[
V
]
1
i
.
.
.
[
V
]
M
2
i
[
C
L
A
S
S
]
[
V
]
M
+
1
2
i
.
.
.
[
V
]
M
i
≠
t
=
[
V
]
1
j
.
.
.
[
V
]
M
2
j
[
C
L
A
S
S
]
[
V
]
M
+
1
2
j
.
.
.
[
V
]
M
j
i
≠
j
a
n
d
i
,
j
ϵ
1
,
2...
K
t=[V]_{1}^i[V]_{2}^i...[V]_{M}^i[CLASS] \neq t=[V]_{1}^j[V]_{2}^j...[V]_{M}^j[CLASS]\\ 或t=[V]_{1}^i...[V]_{\frac{M}{2}}^i[CLASS][V]_{\frac{M+1}{2}^i}...[V]_{M}^i\neq t=[V]_{1}^j...[V]_{\frac{M}{2}}^j[CLASS][V]_{\frac{M+1}{2}}^j...[V]_{M}^j\\i\neq j \ and\ i,j\ \epsilon {1,2...K}
t=[V]1i[V]2i...[V]Mi[CLASS]=t=[V]1j[V]2j...[V]Mj[CLASS]或t=[V]1i...[V]2Mi[CLASS][V]2M+1i...[V]Mi=t=[V]1j...[V]2Mj[CLASS][V]2M+1j...[V]Mji=j and i,j ϵ1,2...K
在训练时就是使交叉熵损失函数最小。
2.实验部分
CLIP+CoOp模型在11个数据集上的平均分数,M=16代表训练的context token的长度为16,mid和end分别代表把CLASS放在中间位置和末尾位置,CSC代表使用的是Class-Specific Context方法。从图中可看出,CLIP+CoOp均比微调的CLIP效果好,并且使用unified context在one-shot时的性能就与Zero-shot CLIP的性能相匹配。
并且我们可在图中看出unified context要比CSC的效果好很多,作者推测,由于是few-shot,可以学到的东西有限,所以性能提升不明显。并且class token放在句子中间还是放在句子末尾位置对性能的影响不大。
3.总结
CoOp将人工设置的Prompt,变为一个learnable prompt,并在CLIP上取得了很不错的效果。但其学到的context连起来无法用正常语言进行解释,引人思考。