一、简介
题目: OpenCon: Open-world Contrastive Learning
会议: TMLR 2023
任务: 给定一个数据集,其中部分样本有标签(可认为它们属于已知类),其余样本无标签(可能属于已知类也可能属于未知类/新类),要求将无标签样本中属于已知类的样本正确分类,对属于新类别的样本进行聚类。
Note: 它与Open World Semi-Supervised Learning与Generalized Category Discovery讲的是同一个故事。
方法:
(1)采用Open World Semi-Supervised Learning或AutoNovel中的方法确定总类别数量(包括已知类和新类)。然后,在backbone后增加一个分类头,其大小等于总类别数量。
(2)采用对比损失进行网络优化。对比损失有3项,分别针对所有有标签的数据、所有无标签的数据、无标签数据中属于新类别的数据。
Note: 对于(1),原文中说是采用Open World Semi-Supervised Learning或Learning to Discover Novel Visual Categories via Deep Transfer Clustering中的方法,因为后者与AutoNovel是同一作者用的是同一方法,并且对AutoNovel我们已经做了解析,所以上面(1)中给出的是Open World Semi-Supervised Learning或AutoNovel。
二、详情
1. 常规的对比损失
L
ϕ
(
x
;
τ
,
P
(
x
)
,
N
(
x
)
)
=
−
1
∣
P
(
x
)
∣
∑
z
+
∈
P
(
x
)
log
exp
(
z
T
⋅
z
+
/
τ
)
∑
z
∗
∈
A
(
x
)
exp
(
z
T
⋅
z
∗
/
τ
)
\mathcal L_\phi(\textbf x;\tau,\mathcal P(\textbf x),\mathcal N(\textbf x))=-\frac{1}{|\mathcal P(\textbf x)|}\sum_{\textbf z^+\in\mathcal P(\textbf x)}\log\frac{\exp(\textbf z^T\cdot\textbf z^+/\tau)}{\sum_{\textbf z^*\in\mathcal A(\textbf x)}\exp(\textbf z^T\cdot\textbf z^*/\tau)}
Lϕ(x;τ,P(x),N(x))=−∣P(x)∣1z+∈P(x)∑log∑z∗∈A(x)exp(zT⋅z∗/τ)exp(zT⋅z+/τ)
其中
∑
z
∗
∈
A
(
x
)
exp
(
z
T
⋅
z
∗
/
τ
)
=
∑
z
+
∈
P
(
x
)
exp
(
z
T
⋅
z
+
/
τ
)
+
∑
z
−
∈
N
(
x
)
exp
(
z
T
⋅
z
−
/
τ
)
\sum_{\textbf z^*\in\mathcal A(\textbf x)}\exp(\textbf z^T\cdot\textbf z^*/\tau)=\sum_{\textbf z^+\in\mathcal P(\textbf x)}\exp(\textbf z^T\cdot\textbf z^+/\tau)+\sum_{\textbf z^-\in\mathcal N(\textbf x)}\exp(\textbf z^T\cdot\textbf z^-/\tau)
z∗∈A(x)∑exp(zT⋅z∗/τ)=z+∈P(x)∑exp(zT⋅z+/τ)+z−∈N(x)∑exp(zT⋅z−/τ)
其中,
x
\textbf x
x为样本,
z
=
ϕ
(
x
)
\textbf z=\phi(\textbf x)
z=ϕ(x)为提取出来的特征,
τ
\tau
τ是一个调节参数,
P
(
x
)
\mathcal P(\textbf x)
P(x)和
N
(
x
)
\mathcal N(\textbf x)
N(x)分别是
z
\textbf z
z的正集和负集。
Note: 为方便理解,我们稍微修改了文中对比损失部分公式,将 N ( x ) \mathcal N(\textbf x) N(x)改为 A ( x ) \mathcal A(\textbf x) A(x),这样 P ( x ) \mathcal P(\textbf x) P(x)与 N ( x ) \mathcal N(\textbf x) N(x)分别为正负集会更容易让人理解。
对比损失能够保证在特征空间中 z + \textbf z^+ z+更接近 z \textbf z z, z − \textbf z^- z−更远离 z \textbf z z,更重要的是它能使模型更关注如何区分 z \textbf z z与困难负样本。简单理解,假设 z \textbf z z为🐈, z + \textbf z^+ z+就是其它的🐈, z − \textbf z^- z−为🐘或🐅。对模型来说🐘与🐈相对容易区分,🐘可被认为简单负样本,🐅与🐈容易被模型分错,那么🐅可被认为是困难负样本。
鉴于对比损失的优良特性,作者定义了如下的损失函数:
其中, L n \mathcal L_n Ln、 L l \mathcal L_l Ll、 L u \mathcal L_u Lu分别为针对无标签数据中属于新类别的数据、所有有标签的数据、所有无标签的数据的对比损失。 λ n \lambda_n λn、 λ l \lambda_l λl、 λ u \lambda_u λu为对应的系数。
在分类已知并聚类已知的问题中,想要使用该损失函数,难题在于如何确定 P ( x ) \mathcal P(\textbf x) P(x)和 N ( x ) \mathcal N(\textbf x) N(x)。
2. 自监督对比损失
我们从最简单的自监督对比损失讲起,它是针对无标签样本的。损失函数如下:
L
ϕ
(
x
;
τ
u
,
P
u
(
x
)
,
N
u
(
x
)
)
=
−
1
∣
P
u
(
x
)
∣
∑
z
+
∈
P
u
(
x
)
log
exp
(
z
T
⋅
z
+
/
τ
u
)
∑
z
∗
∈
A
(
x
)
exp
(
z
T
⋅
z
∗
/
τ
u
)
\mathcal L_\phi(\textbf x;\tau_u,\mathcal P_u(\textbf x),\mathcal N_u(\textbf x))=-\frac{1}{|\mathcal P_u(\textbf x)|}\sum_{\textbf z^+\in\mathcal P_u(\textbf x)}\log\frac{\exp(\textbf z^T\cdot\textbf z^+/\tau_u)}{\sum_{\textbf z^*\in\mathcal A(\textbf x)}\exp(\textbf z^T\cdot\textbf z^*/\tau_u)}
Lϕ(x;τu,Pu(x),Nu(x))=−∣Pu(x)∣1z+∈Pu(x)∑log∑z∗∈A(x)exp(zT⋅z∗/τu)exp(zT⋅z+/τu)
我们将一个批次的无标签数据记为 B u {\mathcal B}_u Bu,对 B u {\mathcal B}_u Bu进行两次随机增强得到 B ~ u \tilde {\mathcal B}_u B~u( ∣ B ~ u ∣ = 2 ∣ B u ∣ |\tilde {\mathcal B}_u|=2|\mathcal B_u| ∣B~u∣=2∣Bu∣)。假设无标签样本为 x \textbf x x,由 x \textbf x x两次随机增强的样本为 x ′ \textbf x^{\prime} x′和 x ′ ′ \textbf x^{\prime\prime} x′′,对应的特征为 z ′ \textbf z^{\prime} z′和 z ′ ′ \textbf z^{\prime\prime} z′′, B ~ u \tilde {\mathcal B}_u B~u对应的特征集便是{ z ′ , A ( x ) \textbf z^\prime,\mathcal A(\textbf x) z′,A(x)}。
Note: 所有的损失都是在 B ~ u \tilde {\mathcal B}_u B~u上进行计算的,不再考虑 B u {\mathcal B}_u Bu。
由于样本无标签,所以只能进行自监督,于是,从 B ~ u \tilde {\mathcal B}_u B~u中选择一个 x ′ \textbf x^\prime x′,对应的特征为 z ′ \textbf z^\prime z′,那么 P u ( x ) {\mathcal P}_u(\textbf x) Pu(x)便只有一个元素,就是由 x \textbf x x增强的另一个样本 x ′ ′ \textbf x^{\prime\prime} x′′的特征 z ′ ′ \textbf z^{\prime\prime} z′′。 A ( x ) \mathcal A(\textbf x) A(x)就是除了 z ′ \textbf z^{\prime} z′之外的 B ~ u \tilde {\mathcal B}_u B~u的所有特征的集合。 N u ( x ) = A ( x ) − P u ( x ) \mathcal N_u(\textbf x)=\mathcal A(\textbf x)-\mathcal P_u(\textbf x) Nu(x)=A(x)−Pu(x),即 N u ( x ) \mathcal N_u(\textbf x) Nu(x)为 B u {\mathcal B}_u Bu去掉 x \textbf x x后所有样本两次随机增强后的特征集。
3. 有监督对比损失
接着,我们讲有监督对比损失,它是针对有标签样本的。我们同样对
B
l
{\mathcal B}_l
Bl(一个批次的有标签数据)进行两次随机增强得到
B
~
l
\tilde {\mathcal B}_l
B~l。相应的损失函数如下:
L
ϕ
(
x
;
τ
l
,
P
l
(
x
)
,
N
l
(
x
)
)
=
−
1
∣
P
l
(
x
)
∣
∑
z
+
∈
P
l
(
x
)
log
exp
(
z
T
⋅
z
+
/
τ
l
)
∑
z
∗
∈
A
(
x
)
exp
(
z
T
⋅
z
∗
/
τ
l
)
\mathcal L_\phi(\textbf x;\tau_l,\mathcal P_l(\textbf x),\mathcal N_l(\textbf x))=-\frac{1}{|\mathcal P_l(\textbf x)|}\sum_{\textbf z^+\in\mathcal P_l(\textbf x)}\log\frac{\exp(\textbf z^T\cdot\textbf z^+/\tau_l)}{\sum_{\textbf z^*\in\mathcal A(\textbf x)}\exp(\textbf z^T\cdot\textbf z^*/\tau_l)}
Lϕ(x;τl,Pl(x),Nl(x))=−∣Pl(x)∣1z+∈Pl(x)∑log∑z∗∈A(x)exp(zT⋅z∗/τl)exp(zT⋅z+/τl)
与自监督对比损失不同的是,有监督对比损失的 P l ( x ) \mathcal P_l(\textbf x) Pl(x)中的元素更多。对于 x ′ \textbf x^\prime x′来说,除 x ′ ′ \textbf x^{\prime\prime} x′′外,我们还希望与 x ′ \textbf x^\prime x′同标签的其它样本所提出的特征也能与 z ′ \textbf z^\prime z′接近。于是,这些特征连同 z ′ ′ \textbf z^{\prime\prime} z′′都应该归入 P l ( x ) \mathcal P_l(\textbf x) Pl(x)。相应地, N l ( x ) \mathcal N_l(\textbf x) Nl(x)就是 B ~ l \tilde {\mathcal B}_l B~l中与 x ′ \textbf x^\prime x′标签不同的所有样本所提特征的集合。 A ( x ) = A l ( x ) + N l ( x ) \mathcal A(\textbf x)=\mathcal A_l(\textbf x)+\mathcal N_l(\textbf x) A(x)=Al(x)+Nl(x)。
Note: 该用法与原文稍有差别,因为作者并未使用 x ′ ′ \textbf x^{\prime\prime} x′′的标签,而是使用 x ′ ′ \textbf x^{\prime\prime} x′′的预测标签。个人认为既然有标签数据都是有标签的对他们充分使用会让结果更可靠。
4. 新类对比损失
最后,讲新类对比损失,它是针对无标签样本集中属于新类的样本的。然而,要针对属于新类的样本,就要先挑选出来这些样本。作者使用了基于原型的分布外检测的方法。原型首先被随机初始化,之后采用moving-average style的方法进行更新。你可以简单的将原型视为每个类别的一个代表特征向量。原型的数量就是总类别的数量,即已知类+新类的数量。
A. 基于原型的分布外检测
将满足如下条件的无标签数据视作新类样本:
其中, μ \pmb{\mu} μ为原型, Y l \mathcal Y_l Yl是已知类的类别集合, λ \lambda λ为阈值。 λ \lambda λ是根据有标签数据确定的,采用p-percentile的方法。例如p=90时,有10%的样本会被认作新类。该等式的意思是说将样本 x i \textbf x_i xi与已知类的原型 μ \bf{\mu} μ做对比,与 ϕ ( x i ) \phi(\textbf x_i) ϕ(xi)越接近, μ j T ⋅ ϕ ( x i ) \pmb{\mu}_j^T\cdot\phi(\textbf x_i) μjT⋅ϕ(xi)就越大,那么 ϕ ( x i ) \phi(\textbf x_i) ϕ(xi)就越可能属于第 j j j类。如果 ϕ ( x i ) \phi(\textbf x_i) ϕ(xi)与最接近的原型的相似度都不大于 λ \lambda λ,就可以认为 x i \textbf x_i xi属于新类。
B. 正负集选择
首先还是对
B
n
{\mathcal B}_n
Bn(一个批次的新类数据)进行两次随机增强得到
B
~
n
\tilde {\mathcal B}_n
B~n。损失函数如下:
L
ϕ
(
x
;
τ
n
,
P
n
(
x
)
,
N
n
(
x
)
)
=
−
1
∣
P
n
(
x
)
∣
∑
z
+
∈
P
n
(
x
)
log
exp
(
z
T
⋅
z
+
/
τ
n
)
∑
z
∗
∈
A
(
x
)
exp
(
z
T
⋅
z
∗
/
τ
n
)
\mathcal L_\phi(\textbf x;\tau_n,\mathcal P_n(\textbf x),\mathcal N_n(\textbf x))=-\frac{1}{|\mathcal P_n(\textbf x)|}\sum_{\textbf z^+\in\mathcal P_n(\textbf x)}\log\frac{\exp(\textbf z^T\cdot\textbf z^+/\tau_n)}{\sum_{\textbf z^*\in\mathcal A(\textbf x)}\exp(\textbf z^T\cdot\textbf z^*/\tau_n)}
Lϕ(x;τn,Pn(x),Nn(x))=−∣Pn(x)∣1z+∈Pn(x)∑log∑z∗∈A(x)exp(zT⋅z∗/τn)exp(zT⋅z+/τn)之前说过,使用对比损失的关键在于
P
(
x
)
\mathcal P(\textbf x)
P(x)和
N
(
x
)
\mathcal N(\textbf x)
N(x)的选取。因为我们总共有
∣
Y
a
l
l
∣
|\mathcal Y_{all}|
∣Yall∣个原型,所以对于新类我们也可以使用如下公式来充分利用它们。
y
^
=
arg
max
j
∈
Y
a
l
l
μ
j
T
⋅
ϕ
(
x
)
\hat{y}=\arg\max_{j\in\mathcal Y_{all}}\pmb{\mu}_j^T\cdot\phi(\textbf x)
y^=argj∈YallmaxμjT⋅ϕ(x)
使用该公式可以预测出新类样本的类别,这样就可以按照有标签样本的方式进行正负集的选择。对于随机增强的 x ′ \textbf x^\prime x′来说, P n ( x ) \mathcal P_n(\textbf x) Pn(x)就应该是与 x ′ \textbf x^\prime x′的预测标签一致的那些样本的特征集合。相应地, N n ( x ) \mathcal N_n(\textbf x) Nn(x)就是那些与 x ′ \textbf x^\prime x′的预测标签不一致的样本的特征集合。 A ( x ) = A n ( x ) + N n ( x ) \mathcal A(\textbf x)=\mathcal A_n(\textbf x)+\mathcal N_n(\textbf x) A(x)=An(x)+Nn(x)。
C. 原型更新
因为原型是随机初始化的,效果肯定不好,所以需要随着训练进程而进行更新。更新当然要把原型朝向对应标签样本特征所在的位置更新。比如, μ j \pmb\mu_j μj就该往属于第 j j j类的样本的特征所在地移动。移动使用moving-average style的方法,规则如下:
简单来说,如果样本 x i \textbf x_i xi是有标签的并且标签是 b b b,那么 μ b \pmb\mu_b μb可直接往它的特征 z \textbf z z的方向移动,如果 x \textbf x x没有标签并通过A. 基于原型的分布外检测 检测为新类,再根据 arg max j ∈ Y n μ j T ⋅ z \arg\max_{j\in\mathcal Y_{n}}\pmb{\mu}_j^T\cdot \textbf z argmaxj∈YnμjT⋅z判断出 j = d j=d j=d,那么 μ d \pmb\mu_d μd就往 z \textbf z z的方向移动,否则,不移动。
至此,评估出总类别数量,并将 L O p e n C o n \mathcal L_{OpenCon} LOpenCon作为损失对网络进行优化后,便可以实现分类已知并聚类未知的目标。