【Novel Category Discovery】A Unified Objective for Novel Class Discovery in ICCV 2021 个人理解

一、简介

题目: A Unified Objective for Novel Class Discovery
会议: ICCV 2021
任务: 给定一个数据集,其中部分样本有标签(可认为它们属于已知类),其余样本无标签(可认为它们属于新类/未知类,未知类与已知类不重叠),要求模型保留对已知类的分类能力同时对无标签样本进行聚类,或称新类发现(Novel Category Discovery,NCD)。
Note: Open World Semi-Supervised LearningGeneralized Category DiscoveryParametric Classification for Generalized Category Discovery对NCD进行了扩展,他们允许无标签样本中包含属于已知类的数据。此外,该方法假设新类数量已知,这通常是不现实的,AutoNovel: Automatically Discovering and Learning Novel Visual CategoriesGeneralized Category Discovery都提出了解决方案。
方法:
(1)建立包含编码器(ResNet18)、有标签分类头(原型线性层)、无标签分类头(映射头MLP+线性层)的网络,将有标签分类头和无标签分类头的logits(未经SoftMax的输出)拼接(concat)起来,再经SoftMax层后由标准交叉熵优化;
(2)对一个图像随机增强出两个视图,有标签时两个视图与原图像标签相同,无标签时,通过Sinkhorn-Knopp algorithm求解伪标签,再将所得两个视图生成的伪标签互换来保证预测一致性;
(3)设置更多的无标签分类头和更大的无标签分类头输出单元以提升聚类精度。

如图,该方法(UNO)使用了一个统一的目标函数(标准交叉熵)。

二、详情

1. 网络搭建

UNO
如图,有标签数据 D l D^l Dl和无标签数据 D u D^u Du都会通过两次随机增强得到两个视图。之后输入网络,网络包含编码器(ResNet18)、有标签分类头(原型线性层)、无标签分类头(MLP+线性层),编码器和头都是共享的。

然后,连接有标签头( h h h)和无标签头( g g g)的logits输出,经SoftMax层得到后验概率,有了该概率即可与真实标签或伪标签一起通过交叉熵训练网络。标签需由zero-pad进行扩充才能在尺寸上与概率匹配,公式如下:
y = { [ y l , 0 C u ] , x ∈ D l [ 0 C l , y ^ ] , x ∈ D u y=\left\{\begin{aligned}[\pmb y^l,\pmb 0_{C^u}], \pmb x\in D^l \\ [\pmb 0_{C^l},\hat{\pmb y}], \pmb x\in D^u \end{aligned}\right. y={[yl,0Cu],xDl[0Cl,y^],xDu其中, y l \pmb y^l yl y ^ \hat{\pmb y} y^为有标签数据的标签和无标签数据的伪标签, 0 C u \pmb 0_{C^u} 0Cu 0 C l \pmb 0_{C^l} 0Cl为长度等于新类数量 C u C^u Cu和已知类数量 C l C^l Cl的全0向量。

有标签样本使用真实标签(GT),无标签样本使用Sinkhorn-Knopp algorithm求解的伪标签(pseudo-label,PL,下一节讲),由交叉熵标准进行优化训练。

2. 多视图伪标签生成

对于一个图像 x \pmb x x,其随机增强的两个视图为 v 1 \pmb v_1 v1 v 2 \pmb v_2 v2。若 x \pmb x x有标签,则 v 1 \pmb v_1 v1 v 2 \pmb v_2 v2的标签与 x \pmb x x的相同,均为 [ y l , 0 C u ] [\pmb y^l,\pmb 0_{C^u}] [yl,0Cu];若 x \pmb x x无标签,则需生成伪标签,由 v 1 \pmb v_1 v1生成 y ^ 1 \hat{\pmb y}_1 y^1,由 v 2 \pmb v_2 v2生成 y ^ 2 \hat{\pmb y}_2 y^2。为了鼓励模型对来自同一图像的两个视图预测一致,采用了如下交换预测任务(swapped prediction task):

其中, ℓ \ell 为标准交叉熵, y 1 \pmb y_1 y1 y 2 \pmb y_2 y2 [ 0 C l , y ^ 1 ] [\pmb 0_{C^l},\hat{\pmb y}_1] [0Cl,y^1] [ 0 C l , y ^ 2 ] [\pmb 0_{C^l},\hat{\pmb y}_2] [0Cl,y^2]。直白的说,就是把伪标签互换了。

y ^ 1 \hat{\pmb y}_1 y^1 y ^ 2 \hat{\pmb y}_2 y^2可由Sinkhorn-Knopp algorithm求解下式得出:

其中, Tr \text{Tr} Tr为迹, H \text{H} H为熵, ϵ > 0 \epsilon>0 ϵ>0是一个超参数, L = [ l g 1 , ⋯   , l g B ] \pmb L=[\pmb l_g^1,\cdots,\pmb l_g^B] L=[lg1,,lgB]是一个批次 B B B个视图经无标签头的logits输出, Y ^ = [ y ^ 1 , ⋯   , y ^ B ] T \hat{\pmb Y}=[\hat{\pmb y}_1,\cdots,\hat{\pmb y}_B]^T Y^=[y^1,,y^B]T是一个批次 B B B个视图的伪标签, Γ \Gamma Γ为transportation polytope,定义如下:

它能保证在一个批次内,对各未知类的数量分配更加均匀,约 B u C u \frac{B^u}{C^u} CuBu个, B u B^u Bu B B B个视图中来自无标签样本的那些视图的个数。

本质上, Tr ( Y L ) \text{Tr}(\pmb{YL}) Tr(YL)为了保证伪标签接近无标签分类头的logits输出, H ( Y ) \text{H}(\pmb{Y}) H(Y)也是为了保证预测的各未知类的数量较为均衡,即不会集中在某一个类上。

3. 多头聚类和过聚类

为了提升聚类精度,作者采用多头聚类和过聚类策略。过聚类头( o o o)与无标签分类头( g g g)结构相似,都是映射头MLP+线性层,只是它的线性层输出单元更多是 C u × m C^u\times m Cu×m个。至于多头,多的是 g g g o o o,分别有 n n n个,记为( g 1 , ⋯   , g n g_1,\cdots,g_n g1,,gn)和( o 1 , ⋯   , o n o_1,\cdots,o_n o1,,on)。

与之前的描述一样,每个 g i g_i gi的logits输出都与有标签分类头( h h h,只有一个)的logits输出连接并经SoftMax与真实标签或伪标签计算标准交叉熵进行优化训练。同样地,每个 o i o_i oi的logits输出也都与 h h h的logits输出连接并计算交叉熵,只不过真实标签和伪标签要将长度通过zero-pad扩充至 C l + C u × m C^l+C^u\times m Cl+Cu×m

至此,整体方法已介绍完毕。之后先在有标签数据上进行预训练使网络具有初始权重,然后在完整数据(有标签+无标签)上进行训练,训练过程中各批次无标签样本的伪标签是当场求解的,所以作者特意强调该方法可用于数据按序列到来的情况。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Fulin_Gao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值