开放世界的半监督学习OPEN-WORLD SEMI-SUPERVISED LEARNING

开放世界的半监督学习

摘要

在现实世界中应用半监督学习的一个基本限制是假设未标记的测试数据只包含以前在标记的训练数据中遇到的类别。然而,这个假设对于野外的数据很少成立,因为在测试时可能会出现属于新类的实例。

在这里,我们引入了一个新的开放世界的半监督学习环境,它将新的类可能出现在未标记的测试数据中这一概念正式化。在这种新的设置中,目标是解决已标记和未标记数据之间的类别分布不匹配问题,在测试时,每个输入实例要么需要被归入现有的类别之一,要么需要初始化一个新的未见过的类别并将实例分配给它。

为了解决这个具有挑战性的问题,我们提出了ORCA,这是一种端到端的方法,它将实例分配到以前见过的类中,或者通过对类似的实例进行分组而形成新的类,而无需假设任何先验知识。ORCA的关键思想是利用不确定的自适应余量来规避因学习已见类比学习新类更快而造成的对已见类的偏见。通过这种方式,ORCA在训练过程中逐渐提高了模型的可辨别性,并减少了所见类与新颖类的类内方差之间的差距。在图像分类数据集和单细胞数据集上进行的大量实验表明,ORCA始终优于其他基线,在ImageNet数据集上实现了25%的可见类改进和96%的新类改进。

1.引言

随着深度学习的出现,已经取得了显著的突破,目前的机器学习系统在具有大量标记数据的任务上表现出色(Hinton等人,2012;LeCun等人,2015;Silver等人,2016;Esteva等人,2017)。尽管有这些优势,但绝大多数模型都是为封闭世界的设置而设计的,其根源在于假设训练和测试数据来自同一组预定义的类别(Bendale & Boult,2015;Boult等人,2019)。然而,这个假设对于野外的数据来说很少成立,因为标记数据取决于是否拥有特定领域的完整知识,而这在实践中很少发生。例如,生物学家可能会预先标记一些已知的细胞类型(见过的类别),然后想要训练并将模型应用于新的组织,以识别已知的细胞类型,但也要发现以前未知的新细胞类型(未见过的类别)。同样,在社交网络中,人们可能想把用户分类到预定的兴趣组中,同时也想发现新的未知/未标记的用户兴趣。因此,与通常假设的封闭世界相反,许多现实世界的问题本质上是开放的--在测试数据中可能会出现在训练期间从未见过(和标记)的新类别。

在这里,我们介绍了开放世界半监督学习(open-world SSL)的设置,它概括了半监督学习和新类发现。在开放世界的SSL下,我们得到了一个标记的训练数据集和一个未标记的数据集。有标签的数据集包含了属于一组所见类的实例,而无标签/测试数据集中的实例既属于所见类,也属于未知数量的未见类(图1)。在这种情况下,模型需要同时将实例分类到以前看到的类中,然后发现新的类并将实例分配给它们。换句话说,开放世界SSL是在类别分布不匹配的情况下的一种过渡性学习设置,其中未标记的测试集可能包含在训练期间从未被标记的类别,即不属于标记的训练集。鉴于未标记的测试集,模型需要将实例分配给先前在标记集中看到的类之一,或者形成一个新的类并将实例分配给它。

开放世界的SSL与最近的两条工作路线有根本的不同,但又密切相关:

鲁棒半监督学习(SSL)和新的类发现。稳健的SSL(Oliver等人,2018;Guo等人,2020;Chen等人,2020b;Guo等人,2020;Yu等人,2020)假设标记数据和未标记数据之间的类分布不匹配,但在这种情况下,模型只需要能够识别(拒绝)未标记数据中属于新类的实例为分布外的实例。相比之下,开放世界SSL不是拒绝属于新类的实例,而是旨在发现个别新类,然后将实例分配给它们。

新颖类的发现(Hsu等人,2018;2019;Han等人,2019;2020;Zhong等人,2021)是一个聚类问题,人们假设未标记的数据只由新颖类组成。(really?)相比之下,开放世界的SSL更具有普遍性,因为未标记数据中的实例可以来自于已见的以及新的类。

  • 为了将稳健的SSL和新类发现方法应用于开放世界的SSL,原则上可以采用多步骤的方法,即首先使用稳健的SSL来拒绝来自新类的实例,然后在被拒绝的实例上应用新类发现方法来发现新类。
  • 另一种方法是,我们可以把所有的类当作 "新的",应用新的类发现方法,然后把一些类与标记数据集中的所见类相匹配。

然而,我们的实验表明,这种临时性的方法在实践中表现并不理想。因此,有必要设计一种能够在端到端框架中解决这一实际问题的方法。

在本文中,我们提出了ORCA(Open-woRld with unCertainty based Adaptive margin),在新的开放世界SSL设置下运行。ORCA有效地将未标记的数据中的例子分配到以前见过的类别中,或者通过对类似实例的分组形成一个新的类别。ORCA是一个端到端的深度学习框架,其中我们的方法的关键是一个新颖的不确定性自适应边际机制,在训练过程中逐渐减少可塑性,增加模型的可辨别性。这种机制有效地减少了由于学习所见类比学习新类更快而造成的所见类内部方差之间的不希望的差距,我们表明这是这种设置中的一个关键困难。然后,我们开发了一个特殊的模型训练程序,学习将数据点分类到一组先前看到的类中,同时也学习为每个新发现的类使用一个额外的分类头。已见类的分类头被用来将未标记的例子分配到标记集的类中,而激活额外的分类头则允许ORCA形成一个新的类。ORCA不需要提前知道新类的数量,可以在部署时自动发现它们。

我们在三个适合开放世界SSL的基准图像分类数据集和一个来自生物学领域的单细胞数据集上评估了ORCA。由于没有现有的方法可以在开放世界的SSL环境下运行,我们首先将现有的最先进的SSL、开放集识别和新类发现方法扩展到开放世界的SSL环境中,然后将它们与ORCA进行比较。 实验结果表明,ORCA有效地解决了开放世界SSL的挑战,并始终以很大的幅度超过了所有基线。具体来说,ORCA在ImageNet数据集的可见类和新奇类上实现了25%和96%的改进。此外,我们表明ORCA对未知数量的新类、所见和新类的不同分布、不平衡的数据分布、预训练策略和少量标记的例子都很稳健。

相关工作

我们总结了开放世界SSL和相关设置的相似性和差异。其他相关的工作在附录A中给出。

新颖的类发现。

在新类发现中(Hsu等人,2018;Han等人,2020;Brbic等人,2020;Zhong等人,2021),任务是对由类似但完全不相干的类组成的未标记数据集进行聚类,而这些类是用来学习更好的聚类表示的。这些方法假设在测试时所有的类都是新的。虽然这些方法能够发现新的类,但它们不能识别已见/已知的类。相反,我们的开放世界SSL更具有普遍性,因为未标记的测试集由新的类组成,但也包括以前在标记的数据中看到的需要识别的类。原则上,人们可以通过在测试时将所有的类视为 "新的",然后将其中一些类与标记数据集中的已知类相匹配来扩展新的类发现方法。我们采用这样的方法作为我们的基线,但我们的实验表明,它们在实践中的表现并不理想。

半监督学习(SSL)。SSL方法(Chapelle等人,2009;Kingma等人,2014;Laine和Aila,2017;Zhai等人,2019;Lee,2013;Xie等人,2020;Berthelot等人,2019;2020;Sohn等人,2020)假设封闭世界的设置,其中标记的和未标记的数据来自相同的类集。健全的SSL方法(Oliver等人,2018;Chen等人,2020b;Guo等人,2020;Yu等人,2020)通过假设来自新类的实例可能出现在未标记的测试集中而放松SSL假设。稳健SSL的目标是拒绝来自新类的实例,这些实例被视为分布外的实例。在开放世界的SSL中,目标不是拒绝来自新类的实例,而是发现个别新类,然后将数据点分配给它们。为了将稳健的SSL扩展到开放世界的SSL,我们可以利用被丢弃的点,然后应用聚类/新类的发现。早期的工作(Miller & Browning, 2003)考虑用这样的方式来解决这个问题,即使用EM算法的扩展。然而,我们的实验表明,通过丢弃实例,这些方法学到的嵌入并不能准确地发现新类。

开放集和开放世界的识别。开放集识别(Scheirer等人,2012;Geng等人,2020;Bendale & Boult,2016;Ge等人,2017;Sun等人,2020a)考虑了在测试过程中可能出现的新类的归纳设置,并且模型需要拒绝来自新类的实例。为了将这些方法扩展到开放世界的设置中,我们包括一个在拒绝的实例上发现类的基线。然而,结果表明,这种方法不能有效地解决开放世界SSL的挑战。同样,开放世界的识别方法(Bendale & Boult,2015;Rudd等人,2017;Boult等人,2019)要求系统逐步学习并以新的类别扩展已知的类别集。这些方法通过人在回路中逐步标注新的类别。相比之下,开放世界SSL在学习阶段利用未标记的数据,不需要人在回路中。

广义零点学习(GZSL)。与开放世界的SSL一样,GZSL(Xian等人,2017;Liu等人,2018;Chao等人,2016)假设在标记集中看到的类和新的类在测试时存在。然而,GZSL对先验知识的可用性提出了额外的假设,这些先验知识作为辅助属性唯一地描述了每个单独的类,包括新类。这一限制性假设严重限制了GZSL方法在实践中的应用。相比之下,开放世界SSL更为普遍,因为它不假设任何关于类的先验信息。 

 提出方法

在本节中,我们首先定义了开放世界的SSL环境。接着,我们概述了ORCA框架,然后详细介绍了我们框架的每个组成部分。

3.1 开放世界的半监督学习设置

在开放世界的SSL中,我们假设有一个过渡性的学习环境,在这个环境中,数据集的标记部分Dl = f(xi; yi)gm i=1和数据集的未标记部分Du = f(xi)gn i=1在输入端被给出。我们分别把在标记数据中看到的类的集合表示为Cl,把未标记的测试数据中的类的集合表示为Cu。在开放世界SSL中,我们假设类别/类的转移,即Cl\Cu 6= ;和Cl 6= Cu。我们认为Cs=Cl\Cu是一个已见类的集合,Cn=CunCl是一个新类的集合。

定义1(开放世界的SSL)。在开放世界的SSL中,模型需要将Du中的实例分配到先前看到的类Cs中,或者形成一个新的类c 2 Cn并将数据点分配给它。 注意开放世界的SSL概括了新的类发现和传统的(封闭世界)SSL。 新的类发现假设在有标签和无标签的数据中的类是不相交的,即Cl\Cu=;,而(封闭世界)SSL假设在有标签和无标签的数据中的类是相同的,即Cl = Cu。

3.2 ORCA的概述 

解决开放世界SSL的关键挑战是既要学习已见/标记的类,也要学习未见/未标记的类。然而,这是一个挑战,因为与新的类相比,模型在所见的类上学习鉴别性表征的速度更快。这就导致所见类的类内方差比新颖类的小。为了规避这个问题,我们提出了ORCA,一种在训练过程中利用不确定性适应性余量减少所见类和新类的类内方差的方法。ORCA的关键观点是利用未标记数据的不确定性来控制所见类的类内方差--如果未标记数据的不确定性很高,我们将执行较大的所见类的类内方差,以减少所见类和新类方差之间的差距,而如果不确定性很低,我们将执行较小的类内方差,以鼓励模型充分使用标记的数据。通过这种方式,我们用不确定性自适应余量来控制所见类的类内方差,并确保所见类的鉴别性表征与新颖类相比不会学习得太快。

给定有标签的实例Xl = fxi 2 RNgn i=1和无标签的实例Xu = fxi 2 RNgm i=1,ORCA首先应用嵌入函数fθ : RN ! RD来获得特征表示Zl = fzi 2 RDgn i=1和Zu = fzi 2 RDgm i=1,分别用于有标签和无标签的数据。这里,zi = fθ(xi),对于每个实例xi 2 Xl [Xu. 在骨干网络的顶部,我们添加一个分类头,由一个单一的线性层组成,参数为权重矩阵W:RD ! RjCl[Cuj,然后是一个softmax层。请注意,分类头的数量被设定为以前看到的类的数量和新类的预期数量。因此,首先jClj头将实例分类到以前见过的类别之一,而其余的头则将实例分配到新的类别中。最终的类/群预测被计算为ci = argmax(WT - zi) 2 R. 如果ci 2 C = l,那么xi属于新的类。新颖类的数量jCuj可以是已知的,并作为算法的输入给出,这是聚类和新颖类发现方法的典型假设。然而,如果不提前知道新类的数量,我们可以用大量的预测头/新类来初始化ORCA。然后,ORCA的目标函数通过不把任何实例分配给不需要的预测头来推断出类的数量,因此这些头永远不会激活。ORCA的目标函数结合了三个部分(图2)。(i)带有不确定性适应余量的监督目标,(ii)成对目标和(iii)规范化项。

其中LS表示监督目标,LP表示成对目标,R是正则化。 η1和η2是正则化参数,在我们所有的实验中设置为1。算法的伪代码在附录B的算法1中进行了总结。我们在附录C中报告了对正则化参数的敏感性分析,接下来讨论每个目标的细节。

3.3 带有不确定性适应性余量的监督目标 

首先,具有不确定性的自适应余量的监督目标迫使网络将实例正确地分配到以前见过的类别中,但控制学习这一任务的速度,以允许同时学习形成新类别的集群。我们利用标记数据的分类注释fyign i=1,优化权重W和骨架θ。分类注释可以通过使用标准的交叉熵(CE)损失作为监督目标加以利用。

然而,在标记数据上使用标准的交叉熵损失会在已见类和新见类之间产生不平衡问题,即梯度对已见类Cs进行更新,但对新见类Cn不更新。这可能会导致对所见类学习一个幅度较大的分类器(Kang等人,2019),导致整个模型偏向所见类。为了克服这个问题,我们引入了一个不确定的自适应余量,并建议将对数归一化,我们将在接下来描述。 

一个关键的挑战是,由于监督目标的存在,所看到的类学得更快,因此与新的类相比,它们往往具有较小的类内方差(Liu等人,2020)。成对目标通过对特征空间中的距离进行排序,为未标记的数据生成伪标签,因此类内方差的不平衡将导致容易出错的伪标签的产生。换句话说,来自新类的实例将被分配到所看到的类。为了减轻这种偏差,我们建议使用一个自适应的余量来减少所见类和新类的类内方差之间的差距。直观地说,在训练开始时,我们要执行一个较大的负边际,以鼓励所看到的类相对于新类的类内方差同样大。在训练接近尾声时,当新类的聚类已经形成时,我们将余量项调整为接近0,这样标记的数据就可以被模型充分地利用,也就是说,目标归结为公式(2)中定义的标准交叉熵。在我们的框架中,带有不确定性的自适应边际的监督目标定义如下。

 其中u¯是不确定性(下文将进一步解释),λ是定义其强度的正则器。在我们所有的实验中,我们把λ设置为1。我们在附录C中展示了这个参数的稳健性。参数s是额外的缩放参数,控制交叉熵损失的温度,在所有实验中设置为10(Wang等人,2018)。该设计与AM-Softmax(Wang等人,2018)有关。

我们建议用不确定性来捕捉类内差异。为了估计不确定性u¯,我们依靠从softmax函数的输出计算出的未标记实例的置信度。在二元设置中,u¯ = jD1uj Px2Du Var(Y jX = x) = jD1uj Px2Du Pr(Y = 1jX) - Pr(Y = 0jX),可以进一步近似为。

 我们用同样的公式作为多类环境下群体不确定性的近似值。为了正确调整余量,我们需要约束分类器的大小,因为分类器的无约束大小会对余量的调整产生负面影响。为了避免这个问题,我们把线性分类器的输入和权重归一,即zi = jz zi ij,Wj = jW Wj jj。

3.4 成对目标

其次,成对目标学习预测成对实例之间的相似性,从而使来自同一类别的实例被归为一组。这一部分目标为未标记的数据生成伪标签,以指导训练。通过使用不确定性自适应余量控制已见类和新类的类内差异,ORCA提高了伪标签的质量。

我们将聚类学习问题转化为成对相似性预测任务(Hsu等人,2018;Chang等人,2017)。考虑到有标签的数据集Xl和无标签的数据集Xu,我们的目标是微调我们的骨干fθ,并学习一个由线性分类器W参数化的相似性预测函数,从而使来自同一类别的实例被分组在一起。为了实现这一目标,我们依靠来自标签集的真实注释和在无标签集上生成的伪标签。具体来说,对于有标签的集合,我们已经知道哪些对应该属于同一类别,所以我们可以使用地面真实的标签。为了获得未标注集的伪标签,我们在一个小批次中计算所有特征表示对zi之间的余弦距离。然后,我们对计算出的距离进行排序,并为每个实例生成其最相似邻居的伪标签。因此,我们只从小型批次中的每个实例的最有把握的正面对中生成伪标签。对于迷你批中的特征表示Zl [ Zu,我们将其最接近的集合表示为Zl0 [ Zu0。请注意,Zl0总是正确的,因为它是使用地面真实标签生成的。ORCA中的成对目标被定义为二元交叉熵损失(BCE)的改进形式。

这里,σ表示softmax函数,它将实例分配到已见或新的类别之一。对于有标签的实例,我们使用地面真实注释来计算目标。对于无标签的实例,我们根据生成的伪标签来计算目标。我们只考虑最自信的配对来生成伪标签,因为我们发现伪标签中增加的噪音对集群学习是不利的。使用不确定性自适应余量,我们控制了已见类和新类的方差,从而提高了生成的伪标签的质量。与(Hsu等人,2018;Han等人,2020;Chang等人,2017)不同的是,我们只考虑正面的配对,我们发现在我们的目标中包括负面的配对并不有利于学习,因为大多数负面的配对可以很容易地识别。我们只考虑正数对的目标与(Van Gansbeke等人,2020)有关。然而,我们以在线方式更新距离和正数对,因此在训练期间受益于改进的特征表示。

3.5 正则化项

最后,正则化避免了将所有实例分配到同一类别的简单解决方案。在训练的早期阶段,网络可能会退化到一个简单的解决方案,即所有的实例都被分配到一个类别,即jCuj = 1。我们通过引入KullbackLeibler(KL)发散项来阻止这种解决方案,该发散项将Pr(yjx 2 Dl [ Du)规整为接近于标签y的先验概率分布P。

其中σ表示softmax函数。如同(Tanaka等人,2018;Arazo等人,2020;Van Gansbeke等人,2020),我们在实验中假设先验概率为均匀分布。这个术语对应于最大熵正则化,它被用于基于伪标签的SSL(Arazo等人,2020)、深度聚类方法(Van Gansbeke等人,2020)和噪声标签的训练(Tanaka等人,2018),以防止类别分布过于平坦。

4.实验

4.1实验设置

数据集。我们在四个不同的数据集上评估ORCA,包括三个标准的基准图像分类数据集CIFAR-10、CIFAR-100(Krizhevsky,2009)和ImageNet(Russakovsky等人,2015),以及一个来自生物学领域的高度不平衡的单细胞小鼠细胞图谱Tabula Muris Senis(Consortium等人,2020)。对于单细胞数据集,我们考虑了一个现实的跨组织细胞类型注释任务,其中未标记的数据与标记的数据相比来自不同的组织(Brbic等人,2020)(详情见附录B)。对于ImageNet数据集,我们按照(Van Gansbeke等人,2020)对100个类进行了细分。在所有的数据集上,我们使用可控的未标记数据和新类的比例。我们首先将类分为50%的可见类和50%的新颖类。我们在附录C中展示了不同比例的结果。然后,我们选择50%的已见类作为标记的数据集,其余的作为未标记的数据集。我们在附录C中展示了只有10%的标记样本的结果。

基线。鉴于开放世界的SSL是一个新的环境,没有现成可用的基线。因此,我们将新的类发现、SSL和开放集识别方法扩展到开放世界的SSL环境。新的类发现方法不能识别已见的类,即把未标记的数据集中的类与先前标记的数据集中已见的类相匹配。我们报告了他们在新类上的表现,并通过以下方式将这些方法扩展到适用于已见类。我们将看到的类视为新的(这些方法有效地对未标记的数据进行了聚类),并通过使用匈牙利算法将一些发现的类与标记数据中的类相匹配来报告看到的类的性能。我们考虑两种方法。DTC(Han等人,2019)和RankStats(Han等人,2020)。另一方面,传统的SSL和开放集识别(OSR)方法不能发现新的类。因此,我们通过以下方式扩展SSL和OSR方法,使其适用于新型类。我们使用SSL/OSR将点分类到已知的类中,并估计出分布外(OOD)的样本。我们报告他们在已见类上的表现,然后我们将K-means聚类(Lloyd,1982)应用于OOD样本以获得聚类(新类)。通过这种方式,我们将两种SSL方法适应于开放世界的SSL环境。深度安全SSL(DS3L)(Guo等人,2020)和FixMatch(Sohn等人,2020),以及最近的深度学习OSR方法CGDL(Sun等人,2020a)。CGDL自动拒绝OOD样本。DS3L通过给OOD样本分配低权重来考虑未标记数据中的新类。为了扩展该方法,我们以最低的权重对样本进行聚类。对于FixMatch,我们根据softmax置信度分数来估计OOD样本。对于这两种SSL方法,我们使用看到的和新的类的分区的基础事实信息来确定OOD样本的阈值。在图像数据集上,我们用SimCLR预训练所有的新类发现和SSL基线,以确保ORCA的好处不是由于预训练造成的。唯一的例外是DTC,它在标记的数据上有自己专门的预训练程序(Han等人,2019)。作为一个额外的基线,我们对SimCLR预训练后得到的表示进行了K-means聚类(Chen等人,2020a)。我们还进行了广泛的消融研究,以评估ORCA的方法的好处。具体来说,我们包括与基线的比较,其中监督目标中的自适应边际交叉熵损失被替换为标准交叉熵损失,即零边际(ZM)方法。此外,为了评估自适应边际的效果,我们将ORCA与固定的负边际(FNM)进行比较。我们发现0:5的余量值达到了最好的性能(附录C),我们在实验中使用这个值。我们将第一条基线命名为ORCA-ZM,第二条基线命名为ORCA-FNM。其他实施和实验细节可在附录中找到。

4.2结果

 

5.结论

我们引入了开放世界的SSL设置,其中新的类可以出现在未标记的测试数据中。在这种情况下,模型需要将实例分配到以前在标记数据中遇到的类中,或者通过对类似实例的分组形成新的类。为了解决这个问题,我们提出了ORCA方法,该方法利用不确定性自适应余量有效地减轻了对已见类的偏见,在训练期间控制已见类和新类的类内差异。我们在图像和单细胞数据集上的广泛实验表明,ORCA有效地解决了这个问题,并以很大的幅度超过了其他基线。我们的工作主张从传统的封闭世界设置转向更现实的机器学习模型的开放世界评估。 

 

  • 5
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值