Robust Semi-Supervised Learning for Self-learning Open-World Classes(ICDM 2023)

自学习开放世界类别的鲁棒半监督学习(ICDM 2023)

Wenjuan Xi, Xin Song, Weili Guo, Yang Yang

南京理工大学,百度人才智库,香港理工大学

paper:https://arxiv.org/pdf/2401.07551.pdf

code: https://github.com/njustkmg/SSOC

@inproceedings{SSOC,
  author       = {Wenjuan Xi and
                  Xin Song and
                  Weili Guo and
                  Yang Yang},
  title        = {Robust Semi-Supervised Learning for Self-learning Open-World Classes},
  booktitle    = {ICDM},
  address      = {Shanghai China},
  pages        = {2374-8486},
  year         = {2023}
}

摘要

现有的半监督学习(SSL)方法假设标记和未标记数据共享相同的类别空间。然而,在实际应用中,未标记数据中总是包含标记集中没有的类别,这可能导致已知类别的分类性能下降。因此,研究人员提出了开放世界SSL方法来处理未标记数据中存在多个未知类别的情况,旨在准确分类已知类别的同时细分区分不同的未知类别。为了应对这一挑战,在本文中,我们提出了一种新颖的自学习开放世界类别的SSL方法(SSOC),它可以明确地自学习多个未知类别。具体而言,SSOC首先为已知类别和未知类别定义类中心标记,并利用交叉注意力机制根据所有样本自主学习标记表示。为了有效发现新类别,SSOC还设计了一种成对相似性损失,除了熵损失,它可以智能地利用未标记数据中的信息,通过实例的预测和关系来发现新类别。广泛的实验证明,SSOC在多个常用分类基准测试中优于最先进的基线方法。具体而言,在具有90%新颖比例的ImageNet-100数据集上,SSOC取得了显著的22%改进。

1、引言

随着机器学习的发展,深度学习在视觉、文本、语音等多个领域取得了显著成就[68,69,70,71]。初期的监督学习依赖于大量昂贵的标记数据进行模型训练,为了减少成本,SSL通过充分利用大量无标记数据的信息,能够在标记数据不足的情况下达到与监督学习相媲美的性能。然而,几乎所有的SSL方法都基于一个的默认假设:未标记数据与标记数据来自相同的分布,不会出现模型没有见过的类别,如Fig1第一行。这使得它们在封闭世界的场景中得到了广泛应用,但在现实世界的情况下,却会导致严重的性能下降[4,5,6,7]。例如,在病理图像分析中,一些病变的组织切片图像可能来自人们未知的疾病,并且不同未知疾病的差异较大,这就要求模型能够区分未知的病理图像类别。在网络安全领域,安全专家会使用模型对软件进行恶意检测,同时希望模型能够区分新的恶意软件类型。因此,有必要研究一种更包容、用途更广的开放世界方法。

近期,有研究学者提出了开放世界SSL的概念[1]。该设定允许未标记数据集中存在标记集中未出现过的类别样本,也就是未知类,如Fig1第三行。其目标是能够同时对已知类和未知类进行分类。此前,开放集SSL方法已经考虑到了开放场景的设定[54,55]。然而,与本文关注的开放世界SSL不同的是,开放集SSL方法仅简单地拒绝未知类样本,以防止无标记数据集中的未知类样本对已知类的分类性能造成危害,其主要任务仍然是分类已知类,如图1的第二行。类似地,新类发现方法[6,7,8]也考虑到了开放场景,但其假设未标记数据集中只包含未知类样本,并且只关注未知类的聚类,而不考虑模型在已知类上的性能。与以上两种方法相比,开放世界SSL是一个更具挑战性的问题,解决该问题的关键在于如何使模型能够更好地学习多个未知类,并确保已知类的分类性能。

开放世界SSL问题引起了广泛的关注,目前已经提出了一些解决该问题的方法[1,2]。这些方法主要从损失函数的角度出发,采用有助于未知类学习的优化目标,并设计了不确定性机制或自适应阈值来缓解学习过程中的类别不平衡现象。虽然这些方法在开放世界SSL中取得了出色的成绩,但它们仍然使用传统的分类模型,面对多个未知类时很容易学习到有偏差的决策边界。此外,简单的特征和标签映射更多地依赖于类别的统计信息和特征分布,导致模型难以真正理解类别的概念,在面对未知类别时通常表现出较差的鲁棒性。因此,目前还缺乏一种能够全面理解类别概念、有效区分类别间差异、挖掘数据中隐藏的模式和结构的开放世界SSL方法。

为此,本文提出了一种新颖的SSL方法,名为Self-Supervised Open-World Class方法,旨在显示地自学习多个未知类别。具体而言,我们首先初始化已知和未知类别的类原型(类中心)表示,然后利用交叉注意力机制结合数据特征,迭代地学习类中心的表示方法,以实现对类别信息的显示建模。为了辅助多个新类的学习,我们选择置信的未标记样本来约束熵损失,并同时使用成对相似性损失来挖掘未标记数据中的信息,以约束实例自身表示和预测层面的一致性。SSOC的模型架构不仅能够保证已知类别的分类性能,还注重学习类别之间的差异性和相似性。通过学习类原型的特征表示,SSOC能够执行有明确物理意义的分类操作,这为探索可解释性学习方法在开放世界SSL任务中的应用提供了新的思路。

综上所述,我们的工作具有以下贡献:(1)我们提出了一种新颖的自学习开放世界类别的SSL方法。该方法利用交叉注意力机制显示地建模类别概念,能够自主地学习多个未知类别。(2)我们设计了成对相似性损失,该损失能够智能地利用未标记数据中的信息,通过实例的预测和关系发现新的类别。(3)我们在CIFAR-10、CIFAR-100、ImageNet-100以及不同数据划分上进行了实验,证明了SSOC方法的有效性,并且在标记数据缺乏和新类数量众多的情景下,证明了SSOC出色的鲁棒性。

2、相关工作

开放世界SSL与SSL、新类检测、开放集识别密切相关。在本节中,我们总结了这些研究方向的异同,并调研了它们的发展历程。

2.1 半监督学习

SSL的目标是利用大量未标记数据和少量标记数据提高模型的学习性能,从而解决标记成本高昂的问题。近年来,深度SSL迅速发展,并可分为以下几个类别[1,2]:一致性正则化方法、伪标记方法、生成式方法、基于图的方法和混合方法等。其中,混合方法通常结合多种主流思想。例如,MixMatch[3]结合了一致性正则化和伪标记方法,通过对未标记样本进行随机数据增强并MixUp[4]伪标签来挖掘未标记数据中的信息;而FixMatch[5]则使用强增强图像和预测的弱增强图像的伪标签来学习一致性目标。尽管现有的SSL方法取得了巨大成功,但它们通常假设未标记数据集中只包含标记数据集中出现过的样本。一旦这个假设不成立,SSL方法会将未知样本与已知类别混为一谈。这不仅导致模型无法识别未知类别,还会因为未知类别的分布与已知类别差异较大而降低模型在已知类别上的性能。因此,传统的SSL方法无法应对开放世界问题。

2.2 开放集检测与开放场景半监督学习

开放集识别(Open set recognition,OSR)拓宽了传统的封闭环境设定,假设测试时会出现未知情景[54,55]。一个鲁棒的OSR模型应该能够正确分类已知类别并识别未知类别。目前的OSR方法可以分为判别式模型和生成式模型[54]。判别式模型通常采用阈值设计或调整预测概率分布的方法来拒绝预测概率较低的未知类别样本[56,57,58,59],例如,方法[56]中提出用OpenMax层替代SoftMax层,利用Weibull分布拟合激活向量以估计样本属于未知类别的概率。生成式模型旨在通过生成未知类别样本来欺骗鉴别器,从而实现对抗学习[60,61,62,63]。

开放集SSL将开放场景引入 SSL中,假设未标记数据集中可能包含未知类别的样本,但测试集仅包含已知类别,旨在减少未知类别对已知类别学习的负面影响,提高模型的鲁棒性。近年来,涌现了许多开放集SSL方法,其核心思想主要集中在降低未知类别样本的权重[64,66]或有选择地使用未知类别样本[65,67],例如DS3L[64]和Robust SSL[66]通过降低未知类别样本的权重来减少未知类别的干扰,UASD[65]则提出了一种不确定性感知的自我蒸馏方法,以防止模型对未知类别过于自信。然而,这些方法仅关注已知类别的性能,无法区分未标记数据中可能存在的多个未知类别。因此,开放集SSL方法无法解决开放世界SSL问题。

2.3 新类发现

新类发现场景属于弱监督学习[6,7,8],它使用一个带标签的数据集和一个无标签的数据集进行训练。与SSL不同的是,新类发现(NCD)假设无标签训练集和测试集中的样本都来自未知类别,并旨在能够对未知类别进行聚类。最初的NCD方法通常采用两阶段学习方式,首先在带标签的数据集上学习先验知识,然后使用类似于迁移学习的方法来聚类无标签集中的未知类别。例如,DTC[9]提出了一种深度转移聚类方法,通过利用已知类别的先验知识增强对未知类别的表示,并学习类别原型来完成未知类别的聚类。RankStats[10]则提出同时使用带标签和无标签数据进行训练可以减少表示偏差,然后利用秩统计方法将先验知识转移到未知类别。然而,与本文关注的开放世界SSL场景不同,NCD方法仅关注未知类别的聚类效果,忽视了已知类别的分类任务。因此,当测试集中同时存在已知类别和未知类别时,NCD方法难以获得出色的综合性能。

2.4 开放世界半监督学习

工作[1]提出了开放世界SSL设定,该设定假设未标记数据集中可能存在多个未知类别,旨在既能对已知类别进行分类,又能发现多个未知类别。相较于先前的研究工作,开放世界SSL场景更接近实际情况,但目前仍处于起步阶段。ORCA[1]是解决该问题的第一个端到端的深度学习框架,它构建了一种不确定性自适应边缘机制来增强对未知类别的学习,并取得了出色的结果。NACH[2]设计了自适应阈值来平衡已知类别和未知类别的学习,并提出了一种新的分类损失函数来帮助模型学习未知类别,取得了比ORCA更好的性能。

3、前置定义

在开放世界的半监督学习(SSL)背景下,训练集包括一个带有$\mathcal{M}$个标记样本的标记数据集$\mathcal{D}^l={\{(x_i, y_i)\}}_{i=1}^\mathcal{M}$和一个包含$\mathcal{N}$个未标记样本的未标记数据集$\mathcal{D}^u={\{x_i\}}_{i=1}^\mathcal{N}$,其中$x \in \mathbb{R}^D$$D$表示图像输入的维度。我们定义$\mathcal{C}^l$为出现在$\mathcal{D}^l$中的类别集合,$\mathcal{C}^u$为出现在$\mathcal{D}^u$中的类别集合。在标记数据集$\mathcal{D}^l$中,$y \in \mathcal{C}^l={\{1,...,|\mathcal{C}^l|\}}$,而在未标记数据集$\mathcal{D}^u$中,x属于$\mathcal{C}^u$中的某个类别。假设$\mathcal{C}^l \ne \mathcal{C}^u$$\mathcal{C}^l \cap \mathcal{C}^u \ne \emptyset$,我们定义$\mathcal{C}^s=\mathcal{C}^l \cap \mathcal{C}^u$为已知类别集合,$\mathcal{C}^n=\mathcal{C}^u \setminus \mathcal{C}^l$为未知类别集合。$S=|\mathcal{C}^s|$$N=|\mathcal{C}^n|$分别表示已知类别和未知类别的数量。

在先前提到的与开放世界SSL相关的方法中,SSL中没有新类的概念,并默认情况下假设$\mathcal{C}^l=\mathcal{C}^u$,而NCD则假设未标记数据不包含任何已知类别,即$\mathcal{C}^l \cap \mathcal{C}^u=\emptyset$。因此,开放世界SSL在本质上更具挑战性。

4、问题设置

在本节中,我们首先介绍了自学习开放世界类别的模型架构,然后形式化了有助于发现多个未知类的学习目标,并在最后给出了整体算法流程。

4.1 自学习开放世界类别

从形式化的问题定义中可以了解到,开放世界SSL问题的关键在于如何合理利用未标记数据集中的未知类样本。SSOC的主要思想是显示地自学习开放世界类别,不管是已知类还是未知类,也就是说我们希望学习到类别在与图像特征维度相同下的表示方法。仅仅获取类原型表示并不是一个困难的问题,先前的许多无监督聚类方法就可以将数据分为几个簇,并计算出每个簇的中心,但这些方法没有利用任何标记信息,并且只是单纯地优化特征嵌入模型来实现好的聚类性能,并不会将类中心纳入学习过程。在开放世界SSL训练过程中,数据以批次的方式输入模型,我们希望模型可以尽可能地发现每个批次中的类别信息并对类中心实施动态调整,在反向传播时应用明确的标记信息和潜在的样本相似关系优化类中心和图像的特征表示。为了实现这一功能,SSOC的核心模块使用交叉注意力机制,这也是SSOC除特征提取器外唯一需要的包含参数的网络层。

交叉注意力机制是一种能够捕捉两个序列之间相关性特征的方法。它最早出现在Transformer[72]中,用于在解码器部分将输入序列和编码器的输出序列进行融合,以获取与解码器当前位置相关的编码器信息。交叉注意力机制对于序列建模和自然语言处理任务具有重要贡献,例如在图像文本分类中用于合并多模态输入序列[73],在机器翻译中捕捉序列中远距离位置之间的依赖关系[74]等。交叉注意力机制的内部结构包含三个矩阵,分别是查询矩阵$W^Q$、键矩阵$W^K$和值矩阵$W^V$。输入模型的两个嵌入序列可以来自不同的模态,但必须具有相同的维度。交叉注意力机制能够计算序列之间的关联性,并根据查询序列的权重来融合与之相关的值序列。这种机制使得模型能够聚焦于序列中重要的相关信息,从而提高任务的性能和表现能力。

在本文的场景下,我们假设$X={\{x_i|x_i \in \mathbb{R}^D \}_{i=1}^B}$为第$t$个批次的图像数据,其中B表示批次大小。经过预训练的深度神经网络$f_{\theta}: \mathbb{R}^D \to \mathbb{R}^d$,我们可以得到嵌入特征$\mathcal{Z}=\{f_{\theta}(X)\}$$\forall z \in \mathcal{Z}, z \in \mathbb{R}^d$。在这里,$\theta$表示模型参数,$d$表示嵌入向量的维度。值得注意的是,为了方便后续计算,SSOC将$W^Q$$W^K$$W^V$的维度设置为$d \times d$。我们记第$t-1$个批次得到的类中心为$\mathcal{A}_{t-1}={\{a_i\}}_{i=1}^{S+N}$,特别地,$\mathcal{A}_0$表示随机初始化的类中心特征矩阵。其中$a_i$表示第$i$个类别的特征向量。在交叉注意力机制中,我们将类别特征视为查询输入、数据特征视为键输入和值输入,并使用相应的参数矩阵进行点积计算,即$Q=\mathcal{A}_{t-1}W^Q, K=ZW^K, V=ZW^V$,于是,我们采用的交叉注意力机制可以表示为:

$$$ \begin{equation}\label{equ1} \Delta \mathcal{A}={\rm SoftMax}(\frac{QK^T}{\sqrt{d_k}})V \end{equation}$$$

两个矩阵的点积通常可以代表向量间的相似程度,在上式中,$QK^T$是大小为$(S+N) \times B$的注意力矩阵,该矩阵中第$i$行、第$j$列的元素可以看作是第$i$个类中心与batch中第$j$个样本的相关性,其值越大,表示样本越有可能属于该类别。注意力分数经过SoftMax层后再与数据特征加权求和,最终得到大小为$(S+N) \times h$的交叉注意力矩阵$\Delta \mathcal{A}$,其中的第$i$行正好是第$i$个类中心与数据特征加权求和后的特征向量,在这个batch的数据中,与第$i$个类别越相似的样本对$\Delta \mathcal{A}_i$的贡献越大,反之贡献很小。因此,$\Delta \mathcal{A}$可以近似看作是该批次样本的类中心特征向量,与先前的类中心计算残差得到更新后的类中心$\mathcal{A}_{t}$:  

$$$ \begin{equation}\label{equ2} \mathcal{A}_{t}=\mathcal{A}_{t-1} + \Delta \mathcal{A} \end{equation}$$$

这样一来,$\mathcal{A}_{t}$既保留了大部分历史数据的类别信息,又以残差的方式引入了当前新发现的类别信息,这样。在下一次迭代中,它将作为上一步的类中心参与计算注意力分数。在若干次迭代之后,\ref{equ2}得到的类中心容纳了数据集中所有的类别信息,能够对,通过这种方式,SSOC利用交叉注意力机制实现了数据特征和类中心的动态交互,实现了类别的自学习。相较于简单地使用聚类来获取类中心,我们的方法不容易受到极端偏离分布规律的数据的影响,能够更好地捕捉样本与不同类别之间的相关特征。

在前面的部分,我们获得了类中心的残差表示,表示为$\Delta \mathcal{A}$。随后,SSOC采用基于距离的方法,利用点积相似性来衡量样本与同一特征空间内类中心之间的距离,以进行可解释的分类。然后,使用激活函数将这些距离归一化为概率分布。通过我们的实验证明,在已知类别和未知类别之间实现平衡学习过程存在挑战,通常导致对未知类别的学习缓慢,并过早地过拟合已知类别。为解决这个问题,我们调整了SoftMax函数计算标签数据预测概率时的温度参数,同时保留未标记数据的默认值为1。这种调整涉及使用较大的温度参数,导致较平坦的概率分布,而较小的值则使输出更加尖锐。通过修改温度参数,我们的目标是降低模型对已知类别的预测信心,从而减轻对已知类别的过拟合并增强对未知类别的学习。因此,对于样本$x_i$,其概率分布可以表示如下。

$$$\begin{equation}\label{equ11} p_i=\begin{cases} \sigma(z_i \cdot {\Delta \mathcal{A}}^T /\epsilon)& \text{$x_i \in \mathcal{D}^l$}\\ \sigma(z_i \cdot {\Delta \mathcal{A}}^T) & \text{Others.} \end{cases} \end{equation}$$$

其中$\sigma$表示SoftMax优化算子,$\epsilon$是缩放超参数,$p_{ij}$表示样本$x_i$属于类别j的预测概率。在后续讨论中,我们使用$\hat{p}_i=\max_j{p_{i}}$表示最大的置信分数,$\hat{y}_i=\arg \max p_i$表示样本的伪标签,然后将其转换为独热形式。最后,我们在图2中说明了SSOC的架构。

4.2 学习目标

为了辅助SSOC学习开放世界类别,我们设计了有助于未知类学习的优化目标。具体来说,我们的损失函数包含三部分:能够选择置信未标记数据的交叉熵损失(CE)、能够选择置信相关样本对的成对相似性损失(BCE)和能够防止已知类过拟合的正则化项(RE)。SSOC的总目标如下所示:

$$$ \begin{equation}\label{equ4} \mathcal{L}=\mathcal{L}_{CE}+\alpha \mathcal{L}_{BCE}+\beta \mathcal{L}_{RE} \end{equation}$$$

 其中$\alpha$$\beta$为平衡超参数。接下来的内容,我们将详细地介绍三个损失目标。

交叉熵

交叉熵是一种用于衡量两个概率分布之间差异的度量指标,广泛应用于分类任务和概率估计中。对于标记数据,SSOC通过最小化交叉熵损失来充分利用标记信息进行有监督指导。我们将标记数据$\mathcal{D}^l$真实标签的one-hot形式记作$Y^l$,使用公式3计算得标记数据的概率分布$P^l$,该部分的监督损失可以表示为以下形式。

$$$\begin{equation}\label{equ5} \mathcal{L}_{CE}^l=-\frac{1}{\mathcal{M}} \sum_{i=1}^{\mathcal{M}} \sum_{j=1}^{S+N} Y^l_{ij}\log{P^l_{ij}} \end{equation}$$$

对于无标记数据,它们缺乏所需的真实标签,无法直接计算交叉熵。为了解决这个问题,许多SSL方法使用伪标签来计算无标记样本的伪监督损失,以帮助模型挖掘未标记数据中的信息。然而,在许多情况下,数据集中的噪声数据可能会导致错误或不可靠的伪标签,这会干扰模型并降低性能。为了减少噪声干扰并保证模型的鲁棒性,我们使用阈值来筛选置信度高的伪标签用于模型训练。具体而言,通过公式3计算出无标记样本集$\mathcal{D}^u$的概率分布$P^u$和伪标签$\hat{Y}^u$,我们只选择最大置信分数高于阈值$\tau_1$的无标记样本来计算交叉熵损失。  

$$$ \begin{equation}\label{equ6} \mathcal{L}_{CE}^u=-\frac{1}{\mathcal{N}} \sum_{i=1}^{\mathcal{N}}{\mathbb I({\hat p}_i > \tau_1) \log{​{\hat p}_i}} \end{equation}$$$

其中$\mathbb{I}(\cdot)$是指示函数。综上所述,SSOC的总交叉熵损失是上述两部分的加权和。

 $$$\begin{equation}\label{equ7} \mathcal{L}_{CE}=\mathcal{L}_{CE}^l+\gamma \mathcal{L}_{CE}^u \end{equation}$$$

我们在多次实验中发现,在训练初期,模型对于未知类别的了解较少,在预测未标记数据时会产生较高的不确定性。因此,模型最初选择出来的未知类样本数量较少,这些有限的未知类信息不足以使模型对未知类别有一个整体的认知。然而,标记数据的真实标签是明确的,$\mathcal{L}_{CE}^l$的监督信息相对强烈,这可能导致模型向已知类别偏移,并错误地将未知类样本划分为单一的已知类别。为了减轻这种问题的影响,并增强模型对未知类别的关注,我们采用了以下两种策略:我们对未标记数据进行扰动。对于每个未标记图像样本,我们生成两个数据增强图像$r_1$$r_2$,将标记数据的特征向量与它们的特征向量拼接起来,共同用于计算公式1中的类中心增量。在计算伪监督损失时,我们仅保留$r_1$的图像数据,并将$r_2$的伪标签视为$r_1$的伪标签。通过这种方式,模型能够学习到更多未知类别的不变特征。

成对相似性

二元交叉熵损失通常用于处理二分类问题。为了让SSOC学习到更优秀的类别特征,我们使用BCE损失来约束样本对之间的相似关系。这个思想早在NCD中就已经成功实施。在嵌入空间中,两个样本的类别关系只有两种可能:同类或异类。BCE损失的目标是拉近同类样本,拉远异类样本。对于标记数据,我们直接使用真实标签来判断它们是否属于同一类别;对于无标记数据,我们使用样本对特征的余弦相似度来衡量它们之间的相似关系。为了减少不可靠噪声样本的负面影响,我们设置了阈值$\tau_2$,用于筛选具有足够置信度的样本对。

$$$ \begin{equation}\label{equ8} \mathcal{L}_{BCE}=-\sum_{i=1}^{\mathcal{M}+\mathcal{N}}{\sum_{j=1}^{\mathcal{M}+\mathcal{N}}{\mathbb{I}(\min({\hat p}_i, {\hat p}_j)>\tau_2)[s_{ij}\log p_i^T p_j + (1-s_{ij})\log (1-p_i^T p_j)]}} \end{equation} \begin{equation}\label{equ9} s_{ij}=\begin{cases}\mathbb{I}(y_i=y_j) & \text{$x_i,x_j \in \mathcal{D}_l$}\\\cos (z_i, z_j) & \text{Others.}\end{cases} \end{equation}$$$

在最小化BCE损失的过程中,原本特征非常相似的样本对(即$s_{ij}$接近于1),它们的概率分布也被优化得更加相似。相反,特征差异较大的样本对的概率分布被优化得更有差异。因此,在我们的方法中,BCE损失的作用是对齐预测空间和嵌入空间,从而帮助模型学习类别间的差异性和类内的相似性。

最大熵正则化

我们在实验中发现,在训练初始阶段,交叉熵损失对于SSOC的学习起到了主导作用,导致类中心容易聚集在一起,难以区分开。这样一来,所有的未标记数据可能都被错误地分为同一类别,而这不是我们所期望的结果。为了使预测的类别分布更加均匀,我们引入了最大熵正则化项,以增加模型预测的不确定性。最大熵正则化本质上是经验熵,通过观测到的数据频率来估计先验概率分布的不确定性。其形式如下所示。 

$$$ \begin{equation}\label{equ10} \mathcal{L}_{RE}=\sum_{i=1}^{\mathcal{M}+\mathcal{N}}p_i \log p_i \end{equation}$$$

在SSOC中,我们将上述公式应用于所有样本,通过最大化经验熵,使模型的预测分布更加灵活和多样化,从而为未标记数据分配给各个类别提供更多机会。实验证明,最大熵正则化项能够有效提高模型在未知类上的鲁棒性。在算法1中,我们详细描述了SSOC的训练过程。

5、实验

在本节中,我们详细介绍了SSOC的实验设置,并给出了结果分析。

5.1 实验设置

数据集

为了证明SSOC的有效性,我们在CIFAR-10[75]、CIFAR-100[75]、ImageNet-100[76]这三个常用的计算机视觉基准数据集上进行实验。CIFAR-10和CIFAR-100数据集都包含了60000张分辨率为$32 \times 32$的图像,其中50000张用于训练,10000张用于测试,CIFAR-10共有10个类别,每个类别约有6000张图像,而CIFAR-100有100个类别,每个类别约有600张图像。ImageNet-100是从包含1000个类别的ILSVRC2012数据集中选择了100个类别得到的[77],为了方便与其他研究进行对比,我们采用了ORCA和NACH[1,2]中所使用的100个类别。对于所有数据集,我们采用随机裁剪和旋转的数据增强方法。在主实验中,我们将每个数据集前50%的类别视作已知类别,其余视作未知类别,并使用已知类别数据的50%作为标记数据,其余则与未知类数据一起作为未标记数据。除此之外,我们还展示了不同标记比例和新类别比例的实验结果。在本文所有的实验中,我们都采用了随机数据划分,以确保实验结果的普适性。

对比方法

我们将SSOC与SSL、open-set SSL、NCD和现有的open-world SSL方法进行比较。对于SSL和open-set SSL方法来说,它们只能对已知类别进行分类,为了将它们扩展到开放世界SSL场景,需要对未知类别进行K-means聚类,以评估它们在新类别上的性能。我们选择FixMatch作为SSL的代表方法,$\rm DS^3L$和CGDL作为open-set SSL的代表方法。由于SSL方法本身没有新类概念,我们将SoftMax输出的置信度低的样本视作未知类别。对于NCD类方法,我们选择DTC和RankStats作为对比,由于它们只能对未知类别进行聚类,缺乏已知类别上的性能,我们使用匈牙利算法对标记数据中的已知类别和聚类所得类别进行最大权匹配[78,1,2],然后分别评估这些已知类别的结果。对于open-world SSL,我们比较了ORCA[1]和NACH[2]中报告的性能。

实施细节

对于CIFAR-10数据集,我们用ResNet34作为骨干网络,并使用两个Adam优化器分别优化骨干网络和交叉注意力层。对于骨干网络,我们使用1e-4的较小学习率进行微调,而对于交叉注意力层,我们采用稍大的学习率5e-3,侧重于学习类别信息。两个优化器的动量参数均设置为$(0.9,0.99)$。我们使用批大小为128进行训练,共进行200个epoch。对于CIFAR-100数据集,用Resnet18作为骨干网络,骨干网络和交叉注意力层的学习率都设置为1e-4,批大小为512,共进行500个epoch的训练。对于ImageNet-100数据集,用Resnet50作为骨干网络,骨干网络和交叉注意力层的学习率分别设置为1e-5和3e-4,使用批大小为100进行训练,共进行200个epoch。在所有实验中,我们采用了早停策略,并用余弦退火方法动态调整学习率。我们在8个V100 GPU上完成了CIFAR数据集的实验,并用4个NVIDIA 3090部署ImageNet-100的实验。

我们使用经过无监督预训练的ResNet模型来提取更优质的图像特征。在开始训练之前,我们首先利用预训练的骨干网络提取所有数据的初始嵌入向量 $\mathcal{Z}^l \cup \mathcal{Z}^u$,然后通过应用K-means算法对这些向量进行无监督聚类,得到初始化的类中心表示 $\mathcal{A}_0$。这种含有先验知识的类中心初始化有助于模型的学习过程。

评价指标

我们采用了方法[1,2]中使用的评估方式,报告了SSOC在已知类别、未知类别和所有类别上的准确率。另外,在消融实验中,我们还报告了新类上的标准化互信息(NMI)。值得注意的是,由于模型学习到的新类概念是无序的,在计算未知类别和所有类别上的准确率之前,需要使用匈牙利最大权匹配算法做标签对齐,获取未知类的聚类标签和真实标签的最优匹配方式。

5.2 主要结果

主要结果比较

我们在表1中给出了SSOC和对比方法在CIFAR-10、CIFAR-100和ImageNet-100上的分类准确率,在该组实验中,所有方法均使用50%的标记比率和50%的新类比率。观察到在所有数据集上,SSOC在已知类别、未知类别和所有类别上的准确率均优于SSL、open-set SSL和NCD方法在开放世界SSL场景下的扩展结果。与这些方法中表现最佳的RankStats相比,我们在CIFAR-100和ImageNet-100的全部类别上分别提升了30%和42.4%。另外,SSOC还优于两个open-world SSL方法,相较于NACH[2],我们在CIFAR-10的已知类上取得了3.8%的改进,并在具有挑战性的ImageNet-100上,显著提高了未知类的性能,在未知类和所有类别上分别提高了2.9%和3.1%。我们的实验结果表明,SSOC能够有效解决开放世界SSL问题。

改变标记比率

为了证明SSOC在少量标记数据场景下的有效性,我们固定新类比率为50%,对比了标记比率为10%和30%时,ORCA、NACH和SSOC的性能。表2展示了所有类别的准确率,其中部分ORCA、NACH实验结果来自它们的论文。从表中可以观察到,随着标记数据的减少,三个方法的性能都会下降,但是在CIFAR-10和ImageNet-100上,当标记数据从50%减少到10%时,SSOC仅产生了1.89%和7.22%的性能下降,而NACH分别下降了3.2%和12.79%。此外,在ImageNet-100数据集上,当标记比率为10%和30%时,SSOC的总准确率比NACH高出8.66%和$6.09%。遗憾的是,我们在CIFAR-100上标记比率为30%时的结果略差于NACH,但仍比ORCA高出6.95%。总体而言,SSOC具有较强的鲁棒性,能够很好地应对标记数据不足的情况。

改变新类比率

另外,我们固定标记比率为50%,研究了不同新类比率对open-world SSL方法的影响。表3进一步给出了ORCA、NACH和SSOC在10%、30%、70%、90%的新类比率时的全类准确率,由于ORCA和NACH的论文中缺乏这些实验数据,我们复现了他们开源的代码。观察数据发现,随着未知类别增多,三个方法的性能都出现了下降,但SSOC的下降幅度小于另外两个方法。值得关注的是,SSOC在较高新类比率的情景下取得了激动人心的成绩,当新类比率为90%时,SSOC在三个数据集上比NACH高出了10.61-22.3%,在新类比率为70%的ImageNet-100上,相对于ORCA的准确率提升了16.01%。此外,我们在新类比率为10%的CIFAR数据集上没有超过NACH,这表明NACH更注重已知类别的分类,而SSOC更加关注未知类别的发现,在只有少量已知类别的情况下,也能取得优秀的效果。

通过以上三个实验,我们证明了SSOC的有效性以及其出色的鲁棒性和泛化性能,它能比ORCA、NACH更好地应对标记数据缺乏或新类数量众多的情况,应用场景更为广泛,具有很强的现实意义。

5.3 消融实验

为了验证不同损失函数的有效性,我们在ImageNet-100上进行了消融实验,其中标记比率和新类比率均为50%。表4报告了已知类、未知类和全类上的准确率,并计算了未知类上的NMI。在前三行中,我们分别删除了交叉熵损失、成对相似性损失和最大熵正则化项,并将剩余损失作为目标函数。从实验结果中可以观察到,$\mathcal{L}_{CE}$对SSOC有至关重要的作用,模型需要依赖标记数据的监督损失提供重要的地面真实信息,并通过伪监督损失来学习未知类。此外,在删除 $\mathcal{L}_{RE}$ 的实验中,未知类的性能严重下降,证实了最大熵正则化对于未知类的学习是有帮助的。最后,我们发现$\mathcal{L}_{BCE}$能进一步优化模型的整体性能。

5.4 参数敏感性

阈值选择的影响

为了分析阈值对实验结果的影响,我们在CIFAR-100数据集上进行了实验,使用不同$\tau_1$$\tau_2$,标签比例和新颖比例均为50%。在图3(a)和(b)中,我们改变了$\tau_1$的值,并展示了每个训练时期选择的阈值所选未标记样本的数量(a),以及所选未知类别样本的伪标签准确性(b)。可以观察到,较低的$\tau_1$不足以过滤掉错误标记的样本,这可能会干扰模型的分类能力。另一方面,较大的$\tau_1$过度消除了未标记数据,未能为模型提供足够的未知类别样本学习,导致在未知类别上性能不佳。在(c)中,我们提供了不同$\tau_2$值的所选未知类别样本的伪标签准确性。在所有CIFAR-100实验中,我们将$\tau_1$设置为0.6,将$\tau_2$设置为0.8。

损失平衡超参数对损失的影响

为了调查不同损失权重对结果的影响,我们在图4中呈现了SSOC在ImageNet-100数据集上使用两组损失权重的准确性。可以观察到,提高与未知类别相关的损失项的权重似乎会限制模型学习未知类别的能力,进而影响整体性能。这表明损失项$\mathcal{L}_{CE}^u$, $\mathcal{L}_{BCE}$, and $\mathcal{L}_{RE}$与未知类别的学习密切相关,仅强调未知类别的学习可能对整体分类造成伤害。我们可以假设一个极端情况:当监督损失$\mathcal{L}_{CE}^l$可以忽略不计时,SSOC会退化为一个在未标记数据上的聚类算法,没有关于真实类别的任何信息。当使用匈牙利最大权重算法进行标签匹配时,预测标签和真实标签之间的不正确匹配增加,导致未知类别和整体分类的结果不理想。因此,我们需要找到一组最佳的损失平衡权重。

6、总结

在这项工作中,我们提出了SSOC来解决开放世界SSL问题。SSOC利用交叉注意力机制自主学习开放世界中的类别,并利用成对相似性损失从未标记数据中提取信息,通过实例预测和关系发现新颖类别。我们在三个计算机视觉基准数据集上展示了SSOC的有效性,它在性能上优于最先进的基线方法。此外,SSOC在面对有限标记数据和许多新颖类别等挑战时表现出色。

  • 8
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值