OpenLDN

题目:《OpenLDN: Learning to Discover Novel Classes for Open-World Semi-Supervised Learning》

来源:ECCV 2022

欧洲计算机视觉会议(European Conference on Computer Vision)

Abstract

        半监督学习(SSL)是解决监督学习注释瓶颈的主要方法之一。最近的 SSL 方法可以有效地利用大量未标记数据来提高性能,同时依赖于少量标记数据。大多数 SSL 方法的一个常见假设是标记和未标记数据来自相同的数据分布。然而,在许多现实世界场景中,这种情况很难实现,这限制了它们的适用性。在这项工作中,我们尝试解决具有挑战性的开放世界 SSL 问题,该问题不作此假设。在开放世界 SSL 问题中,目标是识别已知类别的样本,并同时检测并聚类属于未标记数据中存在的新类别的样本。这项工作介绍了 OpenLDN,它使用 成对相似性损失 来发现新类别。利用 双层次优化规则,这种成对相似性损失利用标记集中可用的信息来隐式地聚类新类别样本,同时识别来自已知类别的样本。在发现新类别后,OpenLDN 将开放世界 SSL 问题转换为标准 SSL 问题,以使用现有的 SSL 方法实现额外的性能提升。我们广泛的实验表明,OpenLDN 在多个流行的分类基准上超越了当前的最先进方法,同时提供了更好的准确性/训练时间权衡

1 Introduction

        深度学习方法在具有挑战性的监督学习任务上取得了显著进展。然而,监督学习范式假设可以获得大量手动标记的数据,这些数据获取既耗时又昂贵。为了解决这个问题,提出了几种方法,包括半监督学习、主动学习、自监督学习、迁移学习和少样本学习。在这些方法中,半监督学习(SSL)是一种主要的方法,它通过利用大量未标记数据的集合来减少所需的注释量

        尽管最近的 SSL 方法取得了有希望的结果,但它们的主要假设是标记和未标记数据来自相同的分布。然而,这个假设在许多现实世界场景中(开放世界问题)很难满足。例如,未标记数据通常从网络来源挖掘,可能包括未知类别的示例。已经证明,使用这些示例进行训练通常会降低标准SSL方法的性能。为了减轻未知(新)类别的未标记样本的负面影响,提出了不同的解决方案。然而,它们的主要动机仅仅是忽略新类别样本,以防止已知类别的性能下降。与此相反,ORCA 将SSL 问题概括为包含新类别,目标不仅是保持对已知类别的性能,而且是识别新类别的样本。这种现实的 SSL 设置称为开放世界 SSL 问题,是这项工作的重点

        这项工作提出了 OpenLDN,它采用成对相似性损失来发现新类别。这种损失解决了一个成对相似性预测任务,确定一对图像是否属于同一类别。从本质上讲,这个任务类似于无监督聚类问题,通过识别一致的簇来促进新类别的发现。解决成对相似性的关键挑战是在没有访问它们的类别标签的情况下确定一对图像之间的相似性关系。一种常见的克服这一挑战的方法是基于预训练的无监督/自监督特征来估计成对相似性关系。然而,这个过程在计算上很昂贵。为了避免依赖无监督/自监督预训练,相反,我们利用已知类别的标记示例中可用的信息来解决成对相似性预测任务,并引入了一个成对相似性预测网络来生成一对图像之间的相似性分数。为了更新这个网络的参数,我们采用了双层次优化规则,它将已知类别的标记示例中可用的信息转移到学习未知类别上。特别是,我们基于标记示例的交叉熵损失隐式地优化相似性预测网络的参数。这样,我们就在不依赖于无监督/自监督预训练的情况下解决了成对相似性预测任务,这使得整体训练更加高效,同时提供了显著的性能提升

        基于输出概率学习成对相似性关系会导致根据最可能的类别隐式地发现簇,因此,新类别的发现。一旦我们学会识别新类别,我们就可以为新类别样本生成伪标签。这随后允许我们将开放世界SSL问题转换为封闭世界SSL问题,通过利用未标记样本生成的伪标签将新类别样本纳入标记集。这种将开放世界问题转换为封闭世界问题的独特视角特别强大,因为它允许我们利用任何现成的封闭世界SSL方法来实现进一步的改进。然而,这种策略的一个缺点是,为新类别生成的伪标签往往是嘈杂的,这反过来又可能阻碍随后的训练。为了解决这个问题,我们引入了迭代伪标签,这是一种简单有效的方法来处理新类别的噪声伪标签估计。总之,我们的主要贡献是:

(1)我们提出了一个新颖的算法OpenLDN,用于解决开放世界SSL。OpenLDN应用双层次优化规则来确定成对相似性关系,不依赖于预训练的特征

FedoSSL使用预训练

(2)我们提出通过发现新类别将开放世界 SSL 转换为封闭世界 SSL 问题;这允许我们利用任何现成的封闭世界 SSL 方法来进一步提高性能

(3)我们引入了迭代伪标签,这是一种简单有效的方法来处理新类别的噪声伪标签

(4)我们的实验表明,OpenLDN 在显著的差距上超越了现有的最先进方法

2 Relate Works
 

        半监督学习(SSL)是处理监督学习中标签注释瓶颈的流行方法。通常,这些方法为封闭世界设置而开发,其中未标记集合仅包含来自已知类别的样本。封闭世界 SSL 的两种最主要方法是 一致性正则化 和 伪标签。基于一致性正则化的方法最小化不同增强版本的图像之间的一致性损失,以从未标记样本中提取显著特征。基于伪标签的方法通过在标记数据上训练的网络生成未标记样本的伪标签,然后以监督的方式对它们进行训练。最后,混合方法结合了一致性正则化和伪标签

        最近的工作表明,未标记集合中存在新类别样本对已知类别的性能有负面影响。为了解决这个问题,提出了不同的解决方案。在一项工作中,训练了一个权重函数来降低新类别样本的权重。另一项工作基于置信度分数过滤掉新类别样本。在另一项工作中,引入了加权批量归一化以实现对新类别样本的鲁棒性。然而,这些方法都没有尝试解决具有挑战性的开放世界 SSL 问题,其中的目标是检测新类别样本并对它们进行分类。据我们所知,ORCA 是唯一解决这个问题的方法,它通过引入基于不确定性感知的自适应边界的交叉熵损失来减轻训练初期已知类别的过度影响。然而,为了发现新类别,ORCA 依赖于自监督预训练,这在计算上是昂贵的。为了克服对自监督预训练的依赖,OpenLDN 中的成对相似性损失利用已知类别的标记示例中可用的信息,使用双层次优化规则

        新类别发现问题 与 无监督聚类 密切相关。新类别发现和无监督聚类之间的关键区别在于前者依赖于额外的标记集来学习新类别。为了发现新类别,一项工作执行自监督预训练,然后基于自监督特征的秩统计解决成对相似性预测任务。另一项工作扩展了深度聚类框架以发现新类别。成对相似性预测任务也应用于通过从已知类别转移知识来对新类别进行分类。虽然新类别发现方法通常使用多个目标函数,但一项工作使用多视图伪标签和交叉熵损失训练来简化这一点。开放世界 SSL 和新类别发现之间的关键区别在于前者不假设未标记数据只包含新类别样本。因此,新类别发现方法并不直接适用于开放世界 SSL 问题。此外,我们的实验表明,OpenLDN 在开放世界 SSL 中的表现明显优于适当修改过的新类别发现方法

3 Method

        为了识别来自已知类别和新类别的未标记样本,我们引入了 成对相似性损失(pairwise similarity loss) 来隐式地将未标记数据聚类为已知和新类别。这种隐式聚类促进了新类别的发现,并通过 交叉熵损失 和 熵正则化项 进行了补充。接下来,我们为新类别样本生成伪标签,将原始的开放世界 SSL 问题转换为封闭世界 SSL 问题。这种转换允许我们利用现有的现成的封闭世界 SSL 方法来学习已知和新类别,从而获得进一步的收益。我们在图1中提供了我们方法的概述。以下,我们提出问题公式化,并详细介绍我们的方法

3.1 Problem Formulation

表示标量,表示向量,表示矩阵,表示集合

在矩阵中,第一个索引表示行,第二个索引表示列

在开放世界 SSL 问题中,假设有一个标记数据集和未标记数据集

表示标记数据集,一共有个样本

其中,是标记样本,是其对应的标签,属于个已知类别之一

类似地,个未标记样本组成

其中,属于个类别之一,这里类别的总数

然而,在开放世界设置中,包含一些不属于任何已知类别的样本。属于未知类别的样本称为新类别样本,其中每个样本属于个新类别之一,即在开放世界设置中,

C(unlabel)=C(label)+ C(novel)

C 是 Class,类别的数量

3.2 Learning to Discover Novel Classes

为了发现新类别,OpenLDN 利用一个神经网络,参数化为,作为特征提取器。特征提取器通过将输入图像投影到嵌入空间来生成特征嵌入(feature embedding),即

特征嵌入是将数据的原始特征表示转换成数值型向量的过程,这些向量捕捉了数据的重要属性和模式。特征嵌入通常用于机器学习和深度学习中,以改善模型的性能和效率

这里,分别是输入图像和特征嵌入的集合

接下来,为了识别来自新类别的样本,以及对已知类别的样本进行分类,我们应用了一个分类器,参数化为,这个分类器将嵌入向量投影到输出分类空间,即

在这个输出空间中,前logits 对应已知类别,其余的 logits 属于新类别,使用 softmax 激活函数从这些输出得分获得 softmax 概率分数,即

"logits"是模型最后一层输出的原始值,这些值在应用激活函数之前未经转换,通常用于分类任务

发现新类别(discover novel classes)的整体目标:

  1. 成对相似性损失(pairwise similarity loss)
  2. 交叉熵损失(cross-entropy loss)
  3. 熵正则化项(entropy regularization term)

成对相似性损失帮助网络发现新类别,交叉熵损失通过利用真实标签和生成的伪标签来帮助分类已知和新类别,熵正则化有助于避免平凡解

整体目标函数如下:

经过的训练后,被分配到任意最后个 logits 的样本被视为新类别样本

Pairwise Similarity Loss

发现新类别是我们提出方法的核心部分,这是一个无监督聚类问题,可以表示为成对相似性预测任务。特别是,对于聚类,一对图像只能有两种可能的关系,要么它们属于同一个聚类,要么不属于。然而,解决成对相似性预测任务需要监督。之前的方法尝试通过基于预训练特征找到最近邻(标记为同一聚类的成员)来为所有图像对生成成对伪标签。然而,这种方法计算量大,并且受到最近邻估计噪声的影响

        

与这种方法截然不同的是,我们不是依赖预训练来获得成对相似性预测任务的标签,而是基于更可靠的可用的真实注释来学习估计成对相似性分数。为此我们引入了一个成对相似性预测网络

,参数化为。给定一对嵌入向量输出一个成对相似性分数,即

嵌入向量(Embedding Vector)是一种将高维的离散数据(如单词、句子、图像等)转换为低维的连续向量表示的方法。这种转换使得这些数据能够在向量空间中进行数学运算和进一步的处理

的成对相似性分数可以用作最小化成对相似性损失的监督信号。为此,给定一批图像,我们计算所有图像对的输出概率之间的余弦相似度。然后,在我们的成对相似性损失中,我们最小化输出概率的计算余弦相似度分数和估计的成对相似性分数之间的损失

l2损失 即 欧几里得距离

请注意,最小化输出概率的余弦相似度的成对相似性损失至关重要,因为这将隐式地根据最大概率分数形成聚类,从而识别新类别。成对相似性损失如下:

其中,是输出概率矩阵,是特征嵌入矩阵,表示余弦相似度函数

为了优化的参数,我们设计了一个双层次优化(bi-level optimization)过程。由于我们无法访问任何未标记样本的标签,特别是来自新类别的样本,我们利用属于已知类别的标记样本的真实标签。这个双层次优化的主要动机是获得一组参数,这些参数不会降低在已知类别上的性能。因此我们基于标记样本计算的交叉熵损失来优化

优化过程如下:首先,我们使用在 公式1 中引入的组合损失更新特征提取器和分类器的参数,以发现新类别

其中,表示优化参数的学习率

接下来,我们使用监督交叉熵损失,

在标记样本上计算来更新的参数。这里的是真实标签矩阵

更新规则如下:

其中,是优化参数的学习率

由于在 公式4 中的中不是显式的,我们执行双层次优化来计算。这种嵌套优化在大多数支持自动微分的现代深度学习中包中都是可用的,这种双层次优化过程确保了的参数以这样的方式更新,即在已知类别上的分类性能不会降低,因为这是开放世界 SSL 中的主要目标之一

Learning with Labeled and Pseudo-Labeled Data

在上述内容中,我们介绍了成对相似性损失,以通过解决成对相似性预测任务来识别新类别。回想一下,我们的目标是在未标记集中识别新类别的样本,并对已知类别进行分类,同时只允许访问已知类别的有限数量的注释。利用这些可用注释的直接方法将是最小化标记样本上的交叉熵损失。然而,这种方法可能会由于它们的强烈训练信号而对已知类别产生强烈的偏见[7]。为了减轻这种偏见并更有效地利用未标记样本,我们为所有未标记数据生成了伪标签。生成的伪标签可以与真实标签一起使用,以最小化交叉熵损失

按照[45,2,55]中的常见做法,我们基于网络输出概率生成伪标签。为了减少使用不可靠的伪标签进行错误训练的可能性,我们仅在足够自信的预测上生成伪标签。此外,我们基于伪标签的交叉熵学习满足了 SSL 作品中常用的另一个目标,即一致性正则化。这种目标鼓励输出分布对扰动具有不变性,以便决策边界位于低密度区域[12,71]。实现此目标的一种方法是最小化两个随机变换版本的图像的输出概率之间的散度。然而,这会向损失中添加另一个项,并且相应地增加了一个新的超参数。更优雅的方法是使用一个图像的弱变换版本生成的伪标签作为另一个版本的靶标。我们使用从图像的弱变换版本生成的伪标签作为其强增强版本的的目标。我们在下面陈述了我们的伪标签生成过程:

其中,(二元分类的中点),以避免每个数据集的微调。生成伪标签后,我们将它们与真实标签结合起来,,并使用交叉熵损失训练模型。在实践中,我们在一个批次内组合这两组标签。让表示一个批次,这个集合上的交叉熵损失定义为:

其中,是真实标签和生成的伪标签的独热编码(one-hot)矩阵。

Entropy Regularization

将未标记数据分配到不同类别的一个众所周知的缺点是,基于判别性损失(如交叉熵)可能导致所有未标记样本被分配到同一类别的平凡解[78,8,7,20]。我们的成对相似性损失也存在同样的问题,因为这种解决方案也可以最小化我们在公式2中的成对相似性损失。为了解决这个问题,我们在训练目标中引入了一个熵正则化项。实现这一点的一种方法是独立地将熵正则化应用于每个样本的输出。然而,这种熵最小化方式会导致个别输出概率的大幅变化,结果会导致新类别样本的任意类别分配。为了避免这个问题,我们对聚合统计数据应用了熵正则化,在我们的情况下,是整个批次样本概率的平均值。这个熵正则化项防止了一个类别在整个批次中占主导地位,其中大多数未标记样本只被分配到一个类别。这个项不会干扰平衡的类别分配。熵正则化定义为:

其中,是批次的平均概率,b表示批次中的样本数量。
 

3.3 Closed-World Training with Iterative Pseudo-Labeling

在我们发现未标记数据中的新类别后,我们可以将开放世界 SSL 问题重新表述为封闭世界问题以提高性能。为此,我们使用 公式8 为所有新类别样本生成了伪标签:

接下来,使用生成的伪标签,我们将新类别样本添加到标记集中。此时,我们可以应用任何标准封闭世界SSL方法[6,75,63,69]。不幸的是,伪标签往往包含噪声,这可能会阻碍性能。为了减轻噪声的负面影响,我们提议在封闭世界 SSL 训练期间以迭代方式进行伪标签生成。这种新的迭代伪标签方法可以与 EM 算法相关。从这个角度来看,我们迭代地尝试更新伪标签(期望步骤),并通过最小化这些更新后的伪标签上的损失来训练网络(最大化步骤)。重要的是要注意,OpenLDN,包括最终的封闭世界 SSL 重新训练,在计算上比基于无监督/自监督预训练的方法更轻或相当。此外,将开放世界 SSL 问题转换为封闭世界问题是一个通用解决方案,也可以应用于其他方法。我们在补充材料中提供了 OpenLDN 的整体训练算法

4 Experimental Evaluation

为了展示 OpenLDN 的有效性,我们在五个常见的基准数据集上进行了实验:CIFAR-10、CIFAR-100、ImageNet-100、Tiny ImageNet和 Oxford-IIIT Pet  数据集。CIFAR-10 和 CIFAR-100 数据集都包含60K图像(分为50K/10K的训练/测试集),并且分别有 10 个和 100 个类别。ImageNet-100数据集包含 100 个来自 ImageNet 的图像类别。Tiny ImageNet 包含来自 200 个类别的 100K/10K 训练/验证图像。最后,Oxford-IIIT Pet 包含来自 37 个类别的图像,分为 3680/3669 的训练/测试集。在我们的实验中,我们根据已知和新类别的百分比划分了这些数据集。我们将前个类别视为已知类别,其余的视为新类别。对于已知类别,我们随机选择了一部分数据来构建标记集,并将其余部分添加到未标记集中,同时包括所有新类别样

在实现细节方面,除了在 ImageNet-100 数据集上的实验外,我们在所有实验中都使用 ResNet-18作为特征提取器。在 ImageNet-100 数据集上,我们使用 ResNet-50。我们使用一个包含单隐藏层(维度为100)的多层感知机(MLP)来实例化我们的成对相似性预测网络。分类器是一个单一的线性层。为了发现新类别,我们在所有实验中训练了50个周期,批量大小为200(对于ImageNet-100为480)。我们总是使用Adam优化器[35]。对于训练特征提取器和分类器,我们将学习率设置为5e-4(对于ImageNet-100为1e-2)。对于成对相似性预测网络,我们使用学习率为1e-4。我们使用两种流行的封闭世界SSL方法,Mixmatch 和 UDA,进行第二阶段的封闭世界SSL训练。为了在这次封闭世界训练中保持数据平衡,我们为每个新类别选择了相同数量的伪标签。对于迭代伪标签,我们每10个周期生成一次伪标签。更多的实现细节可以在补充材料中找到。

在评估指标方面,我们报告了已知类别的标准准确性。此外,按照[26,25,7,20],我们还报告了新类别的聚类准确性。我们使用匈牙利算法来对预测和真实标签进行对齐,然后测量分类准确性。最后,我们还使用匈牙利算法报告了新类别和已知类别联合准确性

  • 9
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值