【Novel Category Discovery】Open World Semi-Supervised Learning in ICLR 2022 个人理解

文章介绍了ICLR2022会议上关于OpenWorldSemi-SupervisedLearning的任务,即在部分样本有标签,部分无标签的数据集上进行分类和新类发现。方法包括使用SimCLR预训练网络,创建大型分类头以及应用三个损失函数:监督目标损失、成对目标损失和正则化项,以同步已知类与新类的学习,保持概率分布均匀,并确保样本正确分类和聚类。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、简介

题目: Open World Semi-Supervised Learning
会议: ICLR 2022
任务: 给定一个数据集,其中部分样本有标签(这里称其为已知类),其余样本无标签(可能属于已知类也可能属于未知类),要求将无标签样本中属于已知类的样本正确分类,对属于未知类的样本进行聚类或称发现新类。
Note: 这里的Open World Semi-Supervised Learning与Generalized Category Discovery讲的是同一件事情。
方法:
(1)使用SimCLR进行网络预训练以获得更好的特征表达。预训练是以自监督的方式在目标数据集的全部有标签和无标签数据上进行;
(2)通过在Backbone后接一个大的分类头来实现对已知类的分类和新类的发现。大的分类头是指SoftMax层的神经元个数远大于已知类的个数,这样已知类样本可以激活对应的神经元得到分类概率,而新类别则会激活其余神经元(并非其余的全部,因为初始的分类头较大);
(3)为了使网络能够达到(2)中所期望的效果,作者提出了三个损失:
  a. Supervised objective with uncertainty adaptive margin。用于控制网络对已知类的学习速度,使其不会过快,达到与新类的学习速度同步的目的;
  b. Pairwise objective。用于确保有标签样本能与同类别样本分到同一组,并使无标签样本与最近邻样本(可能是有标签的样本也可能是无标签的样本)分到同一组;
  c. Regularization term。用于确保概率预测结果不会集中在某个或某些类别上。

开放世界半监督学习
如上图所示,Open World Semi-Supervised Learning的目标是将属于已知类的无标签集中的样本正确分类,并将属于新类别的样本聚到一个新的组。

二、详情

1. 网络结构

在Backbone(比如ResNet-50)后加一个大的分类头。大的分类头是指SoftMax层的神经元个数远大于已知类的个数,这样已知类样本可以激活对应的神经元得到分类概率,而新类别则会激活其余神经元(并非其余的全部,因为初始的分类头较大)。

2. 损失函数设置

为达到划分已知类并发现新类的目的,作者设置了一个包含3个成分的损失函数:

其中, η 1 \eta_1 η1 η 2 \eta_2 η2为调节因子,作者均设置成了1。

A. Supervised objective with uncertainty adaptive margin

此项的目的为缩小已知类的类内方差与新类的类内方差的差距,以避免新类样本被分到已知类中。说白了就是不希望网络只学习有标签的数据,也关心一下无标签数据中新类的学习。

类内方差由不确定性衡量。目标函数如下:

其中, W W W为Backbone到分类头的权重, z z z为Backbone的输出(即所提取出的特征), Z l \mathcal Z_l Zl是有标签样本所提特征的集合, u ˉ \bar u uˉ为不确定性, λ \lambda λ为调节因子, s s s是控制交叉熵的一个参数。

Note: 经过理解和对比源码第95行,我发现此处 + λ u ˉ +\lambda\bar u +λuˉ应该是写错了,应改为 − λ u ˉ -\lambda\bar u λuˉ。具体原因后面说。

实际的交叉熵长这样:

所以对比前一个目标函数可知,作者主要是修改了有标签样本 z i z_i zi的交叉熵(具体说就是缩小了 z i z_i zi标签对应的SoftMax神经元上的值,注意只有与正确标签对应的预测概率值被缩小了,其余的预测概率值没有变化,此处说缩小是在 − λ u ˉ -\lambda\bar u λuˉ的情况下),以此实现通过不确定性 u ˉ \bar u uˉ来控制网络对有标签样本的学习速度的目的。不确定性 u ˉ \bar u uˉ由下式计算:

其中, D u D_u Du为无标签样本集, Pr \text{Pr} Pr为预测概率。简单来说,一个无标签样本的不确定性就是1减去该样本的预测概率的最大值, u ˉ \bar u uˉ就是所有无标签样本的不确定性的均值。

总结下来就是:

训练初期,对无标签样本的预测的不确定性高(因为预测概率并没有集中在一个类别上),则该损失函数可使已知类与新类的类内方差接近且都较高。(此处,类内方差与交叉熵的关系我并不理解,但是可以理解的是,对于作者提出的 L S \mathcal L_\text S LS来说, log ⁡ \log log后的 e x e x + a \frac{e^x}{e^x+a} ex+aex是个增函数, − log ⁡ ( ∗ ) -\log(*) log()是个减函数, W  ⁣ ⁣ ⋅  ⁣ z − λ u ˉ W\!\!\cdot\!z-\lambda\bar u Wzλuˉ使 x x x减小了, x ↓ x\downarrow x,则 e x e x + a ↓ \frac{e^x}{e^x+a}\downarrow ex+aex,则 − log ⁡ ( ∗ ) ↑ -\log(*)\uparrow log(),则 L S ↑ \mathcal L_\text S\uparrow LS,这样初期不确定性高的时候,模型将样本预测为已知类的损失会比较大,模型为了降低损失就会把概率更多的分配给新类);

训练后期,经过对无标签样本的学习,对无标签样本的预测的不确定性降低,此时认为新类簇已经形成,该损失函数可使模型更加关注对有标签样本的学习(不确定性为0时,则该损失函数退化为常规的交叉熵,这样就只关注有标签的样本了)。

B. Pairwise objective

此项的目的是确保有标签的样本能与同类别样本分到同一组,并使无标签的样本与最近邻样本(可能是有标签的样本也可能是无标签的样本)分到同一组。

目标函数如下:

其中, σ \sigma σ为SoftMax函数,如果 z i z_i zi是有标签的, z i ′ z^\prime_i zi是与 z i z_i zi同类别的随机的有标签样本,如果 z i z_i zi没有标签, z i ′ z^\prime_i zi是距离 z i z_i zi最近的一个样本(可能是有标签的也可能是无标签的)。

从该目标函数可以看出,作者使用余弦相似度衡量分类器对两个希望在同一组的样本的预测概率间的差别。如果分类器对 z i z_i zi z i ′ z^\prime_i zi的概率预测差别大,则损失大,反之,损失小。从而,分类器会向使 z i z_i zi z i ′ z^\prime_i zi接近的方向优化。

C. Regularization term

此项的目的是确保预测出来的概率分布接近先验概率分布,但实际往往没有该先验知识,此时改用Maximum entropy regularization来使预测出来的概率分布接近均匀分布。直白地说就是希望与新类对应的神经元能够被激活。

作者在文中描述的是KL散度(Kullback-Leibler divergence),定义如下:

其中, P ( y ) \mathcal P(y) P(y)为样本在各类别上的先验概率分布。KL散度就是在衡量预测概率的分布和先验概率分布间的差距。差距越大,损失越大。容易理解,该损失希望分类器的概率预测与先验概率分布接近。

遗憾的是,通常是无法事先获取先验概率分布的,所以作者使用最大熵替换KL散度作为损失项。熵的定义如下:
H = − ∑ i P ( i ) l o g P ( i ) H=-\sum_iP(i)logP(i) H=iP(i)logP(i)

其中, P ( i ) P(i) P(i)为对第 i i i个类别的概率预测值。在无任何先验知识的情况下,熵在每个 P ( i ) P(i) P(i)都相等,即均匀分布的情况下取最大值。最大熵是要在已知的约束条件下取熵最大的结果。但是,此处并没有任何约束,所以均匀分布就是作者所期望的最大熵。

实际运用熵作为损失项时,应该使 R = − H \mathcal R=-H R=H,从源码的第70行可以看出作者也是这么做的。因为熵越大,我们希望损失越小,这样分类器的概率预测会分布的更加均匀,从而使新类能够被发现。

在这里插入图片描述
如图所示,作者通过上述三个损失项分别实现同步已知类与新类学习速度、将有标签的同类样本分到一起并将无标签样本与最近邻样本分到一起、使概率预测更接近均匀分布的目的。

3. 预训练与微调

模型搭建好之后,可以对模型进行预训练,作者使用SimCLR进行网络预训练以获得更好的特征表达。预训练是以自监督的方式在目标数据集的全部有标签和无标签数据上进行。

为达到开放世界半监督学习的目的,预训练之后,作者使用上述损失函数在目标数据集上进行半监督的微调操作。

### STiL 方法概述 Semi-supervised Tabular-Image Learning (STiL) 是一种用于处理多模态数据的半监督学习方法[^1]。该方法旨在通过结合表格数据和图像数据来提升模型性能,特别是在标注数据有限的情况下。STiL 的核心目标是从不同模态的数据中提取任务相关信息并加以融合。 #### 多模态分类中的任务相关信息探索 在多模态分类场景下,任务相关信息通常分布在不同的数据源之间。STiL 方法通过设计特定机制,在训练过程中逐步识别哪些特征对于当前任务最为重要[^2]。具体而言: - **跨模态关联建模**:STiL 利用注意力机制捕获表格数据与图像数据之间的潜在关系。这种机制能够动态调整各模态的重要性权重,从而聚焦于最相关的部分[^3]。 - **自监督信号增强**:为了充分利用未标记样本的信息,STiL 引入了自监督学习策略。这些策略可以通过预测旋转角度、对比学习等方式生成额外的学习信号,进一步优化模型参数[^4]。 - **联合表示空间构建**:通过对齐两种模态的嵌入向量,STiL 创建了一个统一的任务相关表示空间。这使得即使某些模态缺失或质量较差时,模型仍能保持较高的鲁棒性和准确性[^5]。 以下是实现上述功能的一个简化代码框架: ```python import torch.nn as nn class STILModel(nn.Module): def __init__(self, tabular_dim, image_channels): super(STILModel, self).__init__() # 图像编码器初始化 self.image_encoder = ImageEncoder(image_channels) # 表格数据编码器初始化 self.tabular_encoder = TabularEncoder(tabular_dim) # 跨模态注意层 self.cross_modal_attention = CrossModalAttention() # 输出层定义 self.classifier = Classifier() def forward(self, table_data, image_data): img_features = self.image_encoder(image_data) tab_features = self.tabular_encoder(table_data) combined_features = self.cross_modal_attention(img_features, tab_features) output = self.classifier(combined_features) return output ``` 此代码展示了如何分别对图像和表格数据进行编码,并利用 `CrossModalAttention` 层完成两者间的交互操作[^6]。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Fulin_Gao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值