一、简介
题目: Parametric Classification for Generalized Category Discovery: A Baseline Study
来源: arXiv
任务: 给定一个数据集,其中部分样本有标签(这里称其为已知类),其余样本无标签(可能属于已知类也可能属于未知类),要求将无标签样本中属于已知类的样本正确分类,对属于未知类的样本进行聚类。
方法:
(1)使用vision transformer(ViT)作为backbone进行特征提取。backbone由自监督权重(self-supervised ImageNet weights)赋初值;
(2)之后结合有监督对比损失、自监督对比损失、自蒸馏损失(self-distillation)、标准交叉熵损失对ViT的最后一个宏块(block)和分类头进行联合微调。
Note: 之前讲的Generalized Category Discovery(GCD)和Open World Semi-Supervised Learning也是解决该问题的,该工作所用的方法(1)与前者基本一致,(2)中自蒸馏损失与后者最大熵部分基本一致。另外,该方法假设新类数量已知,其实这是不现实的,AutoNovel: Automatically Discovering and Learning Novel Visual Categories和GCD都提出了解决方案。
如图,该工作(SimGCD)采用监督对比损失+自监督对比损失+参数分类(自蒸馏损失、标准交叉熵损失)的优化策略。
二、详情
1. 参数分类失效原因分析
作者从以下三个方面对分类器直接用于广义类发现问题失效的原因进行了分析:
(1)应该在哪一层特征的基础上构建分类器;
(2)是否应该联合优化特征提取器和分类器;
(3)对于旧类的预测偏向问题。
通过在不同的标签质量(仅有已知类的标签
→
\rightarrow
→self-label伪标签
→
\rightarrow
→self-distillation伪标签
→
\rightarrow
→已知类未知类均有标签)设置下,对比UNO+(联合优化的策略,该方法使用self-label的伪标签生成策略)和GCD(独立/解耦的优化策略),作者得出的结论如下:
(1)相比projector(相当于backbone后的又一个特征提取器),在backbone后直接构建分类器更合适。并且,伪标签的质量越高,则广义类发现性能越好;
(2)在有高质量伪标签的情况下,相比独立优化,联合优化特征提取器和分类器更合适;
(3)UNO+和GCD都存在倾向于将未知类分到已知类的问题。
于是,作者采取了综合最优的方案:在backbone后直接构建分类器、构造损失联合优化特征提取器和分类器、增加正则化项限制对已知类的倾向。
2. 特征提取
该部分与GCD相应部分基本一致,GCD认为直接对图像数据进行分类或聚类效果必然不理想,需要先提取出空间性质较好的特征。GCD作者表示ViT作为backbone对近邻分类器十分友好,因此将ViT选作backbone。
从上图可以看出,与ResNet50(DINO)提取的特征相比,使用self-supervised ImageNet weights初始化的ViT(DINO)的类内距离更小、类间距离更大。
然后,为了增强ViT的表示学习(特征提取)能力,作者使用目标数据集中的数据对特征提取器进行了微调。主要采取的是对比学习的损失优化策略,将全部样本(包括有标签和无标签的全部数据)视为无标签样本使用如下自监督对比损失:

其中, z i \pmb{z}_i zi和 z i ′ \pmb{z}^\prime_i zi′分别是同一图片随机增强后的两个视图的特征表达, z i = g ( f ( x i ) ) \pmb{z}_i=g(f(\pmb x_i)) zi=g(f(xi)), f ( ∗ ) f(*) f(∗)为backbone(ViT), g ( ∗ ) g(*) g(∗)为projection head(projector, 如MLP)。有标签样本使用如下有监督对比损失:

其中, N ( i ) \mathcal{N}(i) N(i)表示与 z i \pmb{z}_i zi有相同类标签的样本索引。
直白来说,这两个对比损失就是希望来自同样本(如随机增强的两个样本)或同类(有相同真实标签的两个样本)的特征尽量接近。于是,形成如下用于优化特征提取器的损失函数:

其中,前一项用于优化全部图像(不使用标签),后一项用于优化有标签图像。
通过该损失函数仅优化特征提取器(不与分类器联合优化)的结果如ViT(Ours)所示,显然,它具有更好的特征空间表达。不过,之前1. 参数分类失效原因分析中说过,在有高质量伪标签的情况下,联合优化更合适,即对ViT的最后一个宏块和分类器一起优化会有更好的结果。所以这里先不优化,而是先生成高质量的伪标签,使用 L r e p \mathcal L_{rep} Lrep连同分类器的相关损失一起优化。
3. 参数分类
无标签的那部分数据无法直接用于优化分类器,于是作者采用了self-distillation的方法生成伪标签并据此训练分类器。其实就是对来自同一图像的两个随机增强后的图像,选择其中一个遮挡一部分,之后让被遮挡的与未遮挡的特征接近同一个原型。
图片来自self-distillation原文,图中
z
\pmb z
z应为该工作的
h
\pmb h
h,
p
+
\pmb p^+
p+应为该工作的
q
′
\pmb q^{\prime}
q′。
首先,对于一个图像 x i \pmb x_i xi进行两次随机增强,得到两个视图分别为锚(anchor)视图和目标(target)视图,定义锚视图特征为 h i = f ( x i a n c h o r ) \pmb h_i=f(\pmb x_i^{anchor}) hi=f(xianchor)。 f f f为ViT的backbone,所以分类器是直接跟在backbone后的,与最初的策略一致。
之后,随机初始化 K K K个特征原型(prototypes,大小应该与图中 z = f θ ( x ) \pmb z=f_\theta(\pmb x) z=fθ(x)一致,与该工作的 h \pmb h h一致), K K K为总类别数量=旧/已知类数量+新/未知类数量(这里假设新类数量是已知的),原型定义为 C = { c 1 , ⋯ , c K } \mathcal{C}=\{c_1,\cdots,c_K\} C={c1,⋯,cK}。
self-distillation会对锚视图进行遮挡,然后根据 h i \pmb h_i hi与 C \mathcal{C} C中各原型的余弦相似度生成软伪标签(硬标签是指one-hot形式的标签),如下式:

其中, p i ( k ) \pmb p^{(k)}_i pi(k)为 h i \pmb h_i hi在第 k k k个类上生成的软伪标签值,即属于类别 k k k的概率值。它是通过余弦相似度后经SoftMax函数得出的。同样地,目标视图也有软伪标签 q i ′ \pmb q_i^{\prime} qi′,目标视图没有被遮挡,并且 q i ′ \pmb q_i^{\prime} qi′比 p i \pmb p_i pi具有更自信的低熵,直白的说,就是概率值集中在一个类别上更接近one-hot的形式。
由上式可知,与某个类的原型越接近,则在该类上的概率或软伪标签值就越高。显然,我们希望锚视图和目标视图的软标签
p
i
\pmb p_i
pi和
q
i
′
\pmb q_i^{\prime}
qi′接近,并且希望用
q
i
′
\pmb q_i^{\prime}
qi′来引导
p
i
\pmb p_i
pi获得更自信的低熵预测。所以,可用交叉熵定义如下损失:
1
∣
B
∣
∑
i
∈
B
ℓ
(
q
i
′
,
p
i
)
\frac{1}{|B|}\sum_{i\in B}\ell(\pmb q_i^{\prime},\pmb p_i)
∣B∣1i∈B∑ℓ(qi′,pi)其中,
ℓ
\ell
ℓ为标准交叉熵。为了让预测不会偏向已知类,即让更多的概率分配给新类,作者引入最大熵正则项,如下:
H
(
p
ˉ
)
=
−
∑
k
p
ˉ
(
k
)
log
p
ˉ
(
k
)
p
ˉ
=
1
2
∣
B
∣
∑
i
∈
B
(
p
i
+
p
i
′
)
H(\bar{\pmb p})=-\sum_k\bar{\pmb{p}}^{(k)}\log\bar{\pmb p}^{(k)}\\ \bar{\pmb p}=\frac{1}{2|B|}\sum_{i\in B}(\pmb p_i+\pmb p_i^{\prime})
H(pˉ)=−k∑pˉ(k)logpˉ(k)pˉ=2∣B∣1i∈B∑(pi+pi′)其中,
p
i
′
\pmb p_i^{\prime}
pi′为
x
i
\pmb x_i
xi的另一个随机增强的锚图像的伪软标签。在上式中,各类概率呈均匀分布时
H
(
p
ˉ
)
H(\bar{\pmb p})
H(pˉ)最大,
−
H
(
p
ˉ
)
-H(\bar{\pmb p})
−H(pˉ)最小,所以该项可保证概率分配更加均匀,不会偏向已知类。于是,综合伪软标签和最大熵形成如下无标签分类目标:

前一项是希望锚视图和目标视图的软伪标签接近,后一项是希望概率分配更均匀一些,不要偏向已知类。
此外,由于还有一些有标签数据存在,它们也可以用经典交叉熵进行优化:

其中, y i \pmb y_i yi为 x i \pmb x_i xi的真实one-hot标签。于是,形成如下用于优化分类器的损失函数:

由于,在1. 参数分类失效原因分析中已说明作者发现在存在高质量伪标签情况下,联合优化是对广义类发现有帮助的,所以作者综合 L r e p + L c l s \mathcal{L}_{rep}+\mathcal{L}_{cls} Lrep+Lcls来联合优化ViT的最后一个宏块和分类器,而不是独立或分别优化。
综上,SimGCD通过有监督对比损失使得同类特征会更加接近,否则更远;通过自监督对比损失使得来自同一样本的特征更加接近,否则更远;通过自蒸馏损失使得来自同一样本的特征更加接近并使得预测不会偏向已知类;通过标准交叉熵损失使得模型完成对已知类的学习。如此,SimGCD便实现了广义类发现任务。