【Novel Category Discovery】Open World Semi-Supervised Learning in ICLR 2022 个人理解

一、简介

题目: Open World Semi-Supervised Learning
会议: ICLR 2022
任务: 给定一个数据集,其中部分样本有标签(这里称其为已知类),其余样本无标签(可能属于已知类也可能属于未知类),要求将无标签样本中属于已知类的样本正确分类,对属于未知类的样本进行聚类或称发现新类。
Note: 这里的Open World Semi-Supervised Learning与Generalized Category Discovery讲的是同一件事情。
方法:
(1)使用SimCLR进行网络预训练以获得更好的特征表达。预训练是以自监督的方式在目标数据集的全部有标签和无标签数据上进行;
(2)通过在Backbone后接一个大的分类头来实现对已知类的分类和新类的发现。大的分类头是指SoftMax层的神经元个数远大于已知类的个数,这样已知类样本可以激活对应的神经元得到分类概率,而新类别则会激活其余神经元(并非其余的全部,因为初始的分类头较大);
(3)为了使网络能够达到(2)中所期望的效果,作者提出了三个损失:
  a. Supervised objective with uncertainty adaptive margin。用于控制网络对已知类的学习速度,使其不会过快,达到与新类的学习速度同步的目的;
  b. Pairwise objective。用于确保有标签样本能与同类别样本分到同一组,并使无标签样本与最近邻样本(可能是有标签的样本也可能是无标签的样本)分到同一组;
  c. Regularization term。用于确保概率预测结果不会集中在某个或某些类别上。

开放世界半监督学习
如上图所示,Open World Semi-Supervised Learning的目标是将属于已知类的无标签集中的样本正确分类,并将属于新类别的样本聚到一个新的组。

二、详情

1. 网络结构

在Backbone(比如ResNet-50)后加一个大的分类头。大的分类头是指SoftMax层的神经元个数远大于已知类的个数,这样已知类样本可以激活对应的神经元得到分类概率,而新类别则会激活其余神经元(并非其余的全部,因为初始的分类头较大)。

2. 损失函数设置

为达到划分已知类并发现新类的目的,作者设置了一个包含3个成分的损失函数:

其中, η 1 \eta_1 η1 η 2 \eta_2 η2为调节因子,作者均设置成了1。

A. Supervised objective with uncertainty adaptive margin

此项的目的为缩小已知类的类内方差与新类的类内方差的差距,以避免新类样本被分到已知类中。说白了就是不希望网络只学习有标签的数据,也关心一下无标签数据中新类的学习。

类内方差由不确定性衡量。目标函数如下:

其中, W W W为Backbone到分类头的权重, z z z为Backbone的输出(即所提取出的特征), Z l \mathcal Z_l Zl是有标签样本所提特征的集合, u ˉ \bar u uˉ为不确定性, λ \lambda λ为调节因子, s s s是控制交叉熵的一个参数。

Note: 经过理解和对比源码第95行,我发现此处 + λ u ˉ +\lambda\bar u +λuˉ应该是写错了,应改为 − λ u ˉ -\lambda\bar u λuˉ。具体原因后面说。

实际的交叉熵长这样:

所以对比前一个目标函数可知,作者主要是修改了有标签样本 z i z_i zi的交叉熵(具体说就是缩小了 z i z_i zi标签对应的SoftMax神经元上的值,注意只有与正确标签对应的预测概率值被缩小了,其余的预测概率值没有变化,此处说缩小是在 − λ u ˉ -\lambda\bar u λuˉ的情况下),以此实现通过不确定性 u ˉ \bar u uˉ来控制网络对有标签样本的学习速度的目的。不确定性 u ˉ \bar u uˉ由下式计算:

其中, D u D_u Du为无标签样本集, Pr \text{Pr} Pr为预测概率。简单来说,一个无标签样本的不确定性就是1减去该样本的预测概率的最大值, u ˉ \bar u uˉ就是所有无标签样本的不确定性的均值。

总结下来就是:

训练初期,对无标签样本的预测的不确定性高(因为预测概率并没有集中在一个类别上),则该损失函数可使已知类与新类的类内方差接近且都较高。(此处,类内方差与交叉熵的关系我并不理解,但是可以理解的是,对于作者提出的 L S \mathcal L_\text S LS来说, log ⁡ \log log后的 e x e x + a \frac{e^x}{e^x+a} ex+aex是个增函数, − log ⁡ ( ∗ ) -\log(*) log()是个减函数, W  ⁣ ⁣ ⋅  ⁣ z − λ u ˉ W\!\!\cdot\!z-\lambda\bar u Wzλuˉ使 x x x减小了, x ↓ x\downarrow x,则 e x e x + a ↓ \frac{e^x}{e^x+a}\downarrow ex+aex,则 − log ⁡ ( ∗ ) ↑ -\log(*)\uparrow log(),则 L S ↑ \mathcal L_\text S\uparrow LS,这样初期不确定性高的时候,模型将样本预测为已知类的损失会比较大,模型为了降低损失就会把概率更多的分配给新类);

训练后期,经过对无标签样本的学习,对无标签样本的预测的不确定性降低,此时认为新类簇已经形成,该损失函数可使模型更加关注对有标签样本的学习(不确定性为0时,则该损失函数退化为常规的交叉熵,这样就只关注有标签的样本了)。

B. Pairwise objective

此项的目的是确保有标签的样本能与同类别样本分到同一组,并使无标签的样本与最近邻样本(可能是有标签的样本也可能是无标签的样本)分到同一组。

目标函数如下:

其中, σ \sigma σ为SoftMax函数,如果 z i z_i zi是有标签的, z i ′ z^\prime_i zi是与 z i z_i zi同类别的随机的有标签样本,如果 z i z_i zi没有标签, z i ′ z^\prime_i zi是距离 z i z_i zi最近的一个样本(可能是有标签的也可能是无标签的)。

从该目标函数可以看出,作者使用余弦相似度衡量分类器对两个希望在同一组的样本的预测概率间的差别。如果分类器对 z i z_i zi z i ′ z^\prime_i zi的概率预测差别大,则损失大,反之,损失小。从而,分类器会向使 z i z_i zi z i ′ z^\prime_i zi接近的方向优化。

C. Regularization term

此项的目的是确保预测出来的概率分布接近先验概率分布,但实际往往没有该先验知识,此时改用Maximum entropy regularization来使预测出来的概率分布接近均匀分布。直白地说就是希望与新类对应的神经元能够被激活。

作者在文中描述的是KL散度(Kullback-Leibler divergence),定义如下:

其中, P ( y ) \mathcal P(y) P(y)为样本在各类别上的先验概率分布。KL散度就是在衡量预测概率的分布和先验概率分布间的差距。差距越大,损失越大。容易理解,该损失希望分类器的概率预测与先验概率分布接近。

遗憾的是,通常是无法事先获取先验概率分布的,所以作者使用最大熵替换KL散度作为损失项。熵的定义如下:
H = − ∑ i P ( i ) l o g P ( i ) H=-\sum_iP(i)logP(i) H=iP(i)logP(i)

其中, P ( i ) P(i) P(i)为对第 i i i个类别的概率预测值。在无任何先验知识的情况下,熵在每个 P ( i ) P(i) P(i)都相等,即均匀分布的情况下取最大值。最大熵是要在已知的约束条件下取熵最大的结果。但是,此处并没有任何约束,所以均匀分布就是作者所期望的最大熵。

实际运用熵作为损失项时,应该使 R = − H \mathcal R=-H R=H,从源码的第70行可以看出作者也是这么做的。因为熵越大,我们希望损失越小,这样分类器的概率预测会分布的更加均匀,从而使新类能够被发现。

在这里插入图片描述
如图所示,作者通过上述三个损失项分别实现同步已知类与新类学习速度、将有标签的同类样本分到一起并将无标签样本与最近邻样本分到一起、使概率预测更接近均匀分布的目的。

3. 预训练与微调

模型搭建好之后,可以对模型进行预训练,作者使用SimCLR进行网络预训练以获得更好的特征表达。预训练是以自监督的方式在目标数据集的全部有标签和无标签数据上进行。

为达到开放世界半监督学习的目的,预训练之后,作者使用上述损失函数在目标数据集上进行半监督的微调操作。

  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Fulin_Gao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值