基于元学习的半监督小样本分类
引用:Ren, Mengye, et al. “Meta-learning for semi-supervised few-shot classification.” arXiv preprint arXiv:1803.00676 (2018).
论文地址:下载地址
Abstract
在少样本分类(few-shot classification)中,我们关注的是从少量标注样本中学习分类器的算法。近年来,少样本分类的进展主要体现在元学习(meta-learning)方面,在元学习中,定义并训练了一个参数化的模型来表示学习算法,该模型在代表不同分类问题的训练集中进行训练,每个训练集包含少量的标注样本和相应的测试集。在本工作中,我们将这一少样本分类范式推进到一个新场景,其中每个训练集中还包含未标注的样本。我们考虑了两种情况:一种是所有未标注样本都假定属于与该训练集中标注样本相同的类别集合;另一种则是更具挑战性的情况,其中还提供了来自其他“干扰”类别的样本。为了解决这一范式,我们提出了对原型网络(Prototypical Networks)的新扩展,增强了其在生成原型时利用未标注样本的能力。这些模型在训练时采用端到端的方式,通过在每个训练集中学习,成功地利用未标注样本。我们在Omniglot和miniImageNet的版本上对这些方法进行了评估,这些版本经过了修改以适应这种新框架,其中包含了未标注样本。我们还提出了ImageNet的新划分,这个划分由大量类别组成,并且具有层次结构。实验结果表明,我们的原型网络能够通过未标注样本改善预测,类似于半监督学习算法的效果。
1 INTRODUCTION
大规模标注数据的可用性使得深度学习方法在多个与人工智能相关的任务中取得了令人瞩目的突破,如语音识别、物体识别和机器翻译。然而,当前的深度学习方法在处理标注数据稀缺的问题时表现不佳。具体来说,虽然当前的方法在处理单一问题(具有大量标注数据)时表现出色,但缺乏能够同时解决多种具有少量标注的分类问题的方法。而人类则能够迅速学习新的类别,例如当我们访问热带国家时,能很快识别新的水果种类。这种人类与机器学习之间的显著差距为深度学习的发展提供了肥沃的土壤。
因此,近年来出现了越来越多的关于少样本学习(few-shot learning)的研究,少样本学习旨在设计能够更好地泛化到具有小规模标注训练集问题的学习算法。在这里,我们重点讨论少样本分类问题,其中假设给定的分类问题每个类别只有少量标注样本。少样本学习的一种方法采用元学习(meta-learning)的方法,通过从大量可用标注数据生成的各种分类问题池进行迁移学习,应用到训练时未见过的类别的新分类问题中。元学习可以采取学习共享度量,为少样本分类器学习共同初始化。
这些不同的元学习形式最近在少样本分类中取得了显著进展。然而,这些进展在每个少样本学习情境的设定中有所限制,这与人类学习新概念的方式在许多方面存在差异。本文旨在通过两种方式来概括这一设定。首先,我们考虑在存在额外未标注数据的情况下学习新类别的情境。虽然半监督学习在单一分类任务的常规设置中取得了许多成功应用,其中训练和测试时的类别相同,但这些工作并没有解决将学习迁移到训练时未见过的新类别的问题,这是我们在本文中所探讨的。其次,我们考虑学习新类别时,这些类别不是孤立地呈现。相反,许多未标注的样本来自不同的类别;这些“干扰”类别的存在为少样本问题引入了额外且更现实的难度。这项工作是首次研究这种具有挑战性的半监督少样本学习形式。
首先,我们定义了这一问题并提出了适用于评估的基准,这些基准是从普通少样本学习中使用的Omniglot和miniImageNet基准中改编而来。我们对上述两种情境进行了广泛的实证研究,分别是有干扰类别和没有干扰类别的情况。其次,我们提出并研究了三种新型的原型网络扩展方法,将其应用到半监督情境中,原型网络是当前少样本学习的最先进方法。最后,我们在实验中证明了我们的半监督变体能够成功地利用未标注样本,并超越了纯监督的原型网络。
图 1:设定一个场景,其目标是学习一个分类器来区分两个先前未见过的类别(金鱼和鲨鱼),不仅有这两个类别的标注样本,还包含一个更大的未标注样本池,其中部分样本可能属于目标类别中的一个。在本研究中,我们旨在更接近这一更加自然的学习框架,通过在学习过程中引入来自目标类别(用虚线红色边框表示)的未标注数据,以及来自干扰类别的未标注数据。
2 BACKGROUND
我们首先准确定义当前少样本学习的范式,以及原型网络(Prototypical Network)在解决这一问题中的方法。
2.1 FEW-SHOT LEARNING
近期在少样本学习方面的进展得益于采用了一个基于情境(episodic)的方法。考虑一种情况,我们有一个大的标注数据集,包含类别集 C train C_{\text{train}} Ctrain。然而,在基于 C train C_{\text{train}} Ctrain 中的样本进行训练后,我们的最终目标是为一个不重叠的新类别集 C test C_{\text{test}} Ctest 生成分类器,对于这些新类别,只有少量标注样本可用。情境方法的核心思想是模拟测试时将遇到的少样本问题,充分利用 C train C_{\text{train}} Ctrain 类别的大量可用标注数据。
具体来说,模型在 K-shot, N-way 的情境上进行训练,这些情境首先从 C train C_{\text{train}} Ctrain 中采样一个包含 N N N 类的小子集,然后生成:
- 一个训练(支持)集 S = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , … , ( x N × K , y N × K ) } S = \{(x_1, y_1), (x_2, y_2), \dots, (x_{N \times K}, y_{N \times K})\} S={(x1,y1),(x2,y2),…,(xN×K,yN×K)},包含来自每个类别的 K K K 个样本;
- 一个测试(查询)集 Q = { ( x 1 ∗ , y 1 ∗ ) , ( x 2 ∗ , y 2 ∗ ) , … , ( x T ∗ , y T ∗ ) } Q = \{(x^*_1, y^*_1), (x^*_2, y^*_2), \dots, (x^*_T, y^*_T)\} Q={(x1∗,y1∗),(x2∗,y2∗),…,(xT∗,yT∗)},其中包含来自同样 N N N 类的不同样本。
每个 x i ∈ R D x_i \in \mathbb{R}^D xi∈RD 是一个维度为 D D D 的输入向量, y i ∈ { 1 , 2 , … , N } y_i \in \{1, 2, \dots, N\} yi∈{1,2,…,N} 是类别标签(对 x i ∗ x^*_i xi∗ 和 y i ∗ y^*_i yi∗ 也类似)。训练时,将支持集 S S S 输入模型,并更新模型参数,最小化模型在查询集 Q Q Q 中样本的预测误差。
一种理解这种方法的方式是,我们的模型实际上是在训练成一个优秀的学习算法。事实上,和学习算法一样,模型必须接收一组标注样本,并生成一个可以应用到新样本的预测器。此外,训练直接促使模型产生的分类器在查询集的新的样本上具有良好的泛化能力。由于这个类比,这种方法的训练通常被称为“学习学习”或“元学习”(meta-learning)。
另一方面,将情境中的内容称为训练集和测试集,并将这些情境的学习过程称为元学习或元训练(如某些文献中所做)可能会引起混淆。因此,为了清晰起见,我们将情境中的内容称为支持集和查询集,并将遍历训练情境的过程简单地称为训练。
2.2 PROTOTYPICAL NETWORKS
原型网络是一个少样本学习模型,具有简单且取得最先进性能的优点。高层次上,它利用支持集 S S S 从每个类别中提取原型向量,并根据输入样本与每个类别原型的距离对查询集中的样本进行分类。
更精确地说,原型网络学习一个嵌入函数 h ( x ) h(x) h(x),该函数是一个神经网络,用于将样本映射到一个空间,在该空间中,同一类别的样本彼此接近,而不同类别的样本则彼此远离。原型网络的所有参数都位于嵌入函数中。
为了计算每个类别
c
c
c 的原型
p
c
p_c
pc,执行对每个类别嵌入样本的平均操作:
p
c
=
∑
i
h
(
x
i
)
z
i
,
c
∑
i
z
i
,
c
z
i
,
c
=
1
[
y
i
=
c
]
(1)
p_c = \frac{\sum_i h(x_i) z_{i,c}}{\sum_i z_{i,c}} z_{i,c} = 1[y_i = c] \tag{1}
pc=∑izi,c∑ih(xi)zi,czi,c=1[yi=c](1)
这些原型定义了对任何新(查询)样本
x
∗
x^*
x∗ 的预测器,该预测器根据样本
x
∗
x^*
x∗ 与每个原型之间的距离,为每个类别
c
c
c 分配一个概率,如下所示:
p
(
c
∣
x
∗
,
{
p
c
}
)
=
exp
(
−
∥
h
(
x
∗
)
−
p
c
∥
2
2
)
∑
c
′
exp
(
−
∥
h
(
x
∗
)
−
p
c
′
∥
2
2
)
(2)
p(c|x^*, \{p_c\}) = \frac{\exp(-\|h(x^*) - p_c\|_2^2)}{\sum_{c'} \exp(-\|h(x^*) - p_{c'}\|_2^2)} \tag{2}
p(c∣x∗,{pc})=∑c′exp(−∥h(x∗)−pc′∥22)exp(−∥h(x∗)−pc∥22)(2)
用于更新原型网络的损失函数非常简单,它是所有查询样本正确类别分配的平均负对数概率:
−
1
T
∑
i
log
p
(
y
i
∗
∣
x
i
∗
,
{
p
c
}
)
(3)
-\frac{1}{T} \sum_i \log p(y^*_i | x^*_i, \{p_c\}) \tag{3}
−T1i∑logp(yi∗∣xi∗,{pc})(3)
训练通过最小化平均损失进行,在训练情境上迭代,并对每个情境执行梯度下降更新。
泛化性能通过测试集情境来衡量,这些情境包含来自类别 C test C_{\text{test}} Ctest 的图像,而不是 C train C_{\text{train}} Ctrain。对于每个测试情境,我们使用原型网络为提供的支持集 S S S 生成的预测器,将查询输入 x ∗ x^* x∗ 分类到最可能的类别 y ^ = arg max c p ( c ∣ x ∗ , { p c } ) \hat{y} = \arg\max_c p(c|x^*, \{p_c\}) y^=argmaxcp(c∣x∗,{pc})。
3 SEMI-SUPERVISED FEW-SHOT LEARNING
我们现在定义本文中考虑的少样本学习的半监督设置。
训练集表示为标注样本和未标注样本的元组:(S, R)。标注部分是少样本学习文献中的常见支持集 S S S,包含一系列输入和目标的元组。除了经典的少样本学习外,我们还引入了一个仅包含输入的未标注集 R R R: R = { x ~ 1 , x ~ 2 , … , x ~ M } R = \{ \tilde{x}_1, \tilde{x}_2, \dots, \tilde{x}_M \} R={x~1,x~2,…,x~M}。 与纯监督设置相似,我们的模型在训练时也需要在预测情境查询集 Q Q Q 中样本的标签时表现良好。图 2 显示了训练和测试情境的可视化。
图 2:半监督少样本学习设置示例。训练过程中,我们通过训练情境进行迭代,每个情境由一个支持集
S
S
S、一个未标注集
R
R
R 和一个查询集
Q
Q
Q 组成。目标是使用
S
S
S 中的标注样本(以其数字类别标签表示)和
R
R
R 中的未标注样本,在每个情境中进行泛化,以便在相应的查询集上取得良好的性能。
R
R
R 中的未标注样本可能与我们正在考虑的类别相关(上图中用绿色加号表示),或者它们可能是干扰样本,属于与当前情境无关的类别(上图中用红色减号表示)。然而,请注意,模型实际上并没有关于每个未标注样本是否为干扰项的真实标签信息;加号/减号符号仅用于说明目的。在测试时,我们会得到新的情境,这些情境包含在训练时未见过的新类别,我们使用这些情境来评估元学习方法。
3.1 半监督原型网络
在原始的原型网络(Prototypical Networks)中,并未指定如何利用未标注集 R R R。在接下来的部分,我们将提出多种扩展方法,从原始的原型定义 p c p_c pc 出发,并提供一个使用未标注样本 R R R 来生成精炼原型 p ~ c \tilde{p}_c p~c 的过程。
在获得精炼的原型后,每个模型都使用与普通原型网络(Prototypical Networks)相同的损失函数进行训练,如公式 3 所示,但将 p c p_c pc 替换为 p ~ c \tilde{p}_c p~c。也就是说,每个查询样本根据其嵌入位置与相应精炼原型的接近程度被分类为 N N N 类中的一个,并使用正确分类的平均负对数概率进行训练。
图 3:左图:原型网络中的原型初始化基于对应类别样本的均值位置,类似于普通的原型网络。支持集、未标注集和查询集的样本分别用实线、虚线和白色边框表示。右图:通过结合未标注样本得到的精炼原型,这些原型能够正确分类所有查询样本。
3.1.1 基于软 k k k-均值的原型网络
我们首先考虑通过借鉴半监督聚类的一种简单方式来利用未标注样本来精炼原型。将每个原型视为一个聚类中心,精炼过程可以尝试调整聚类位置,使其更好地适应支持集和未标注集中的样本。在这种视角下,支持集中的标注样本的聚类分配被视为已知,并且固定为每个样本的标签。精炼过程必须估计未标注样本的聚类分配,并相应地调整聚类位置(即原型)。
一个自然的选择是借鉴软K-means推断的方式。我们偏好这种K-means版本而非硬分配,因为硬分配会使得推断过程不可微分。我们首先使用常规原型网络的原型 p c p_c pc(如公式 1 所示)作为聚类位置。然后,未标注样本根据它们与聚类位置的欧几里得距离,得到对每个聚类的部分分配 ( z ~ j , c \tilde{z}_{j,c} z~j,c)。最后,通过结合这些未标注样本,得到精炼后的原型。
这一过程可以总结如下:
p
~
c
=
∑
i
h
(
x
i
)
z
i
,
c
+
∑
j
h
(
x
~
j
)
z
~
j
,
c
∑
i
z
i
,
c
+
∑
j
z
~
j
,
c
,
z
~
j
,
c
=
exp
(
−
∥
h
(
x
~
j
)
−
p
c
∥
2
2
)
∑
c
′
exp
(
−
∥
h
(
x
~
j
)
−
p
c
′
∥
2
2
)
(4)
\tilde{p}_c = \frac{\sum_i h(x_i) z_{i,c} + \sum_j h(\tilde{x}_j) \tilde{z}_{j,c}}{\sum_i z_{i,c} + \sum_j \tilde{z}_{j,c}} , \tilde{z}_{j,c} = \frac{\exp \left(-\|h(\tilde{x}_j) - p_c\|_2^2 \right)}{\sum_{c'} \exp \left(-\|h(\tilde{x}_j) - p_{c'}\|_2^2 \right)} \tag{4}
p~c=∑izi,c+∑jz~j,c∑ih(xi)zi,c+∑jh(x~j)z~j,c,z~j,c=∑c′exp(−∥h(x~j)−pc′∥22)exp(−∥h(x~j)−pc∥22)(4)
然后,每个查询输入的类别预测按照公式 2 的方式进行建模,但使用精炼后的原型
p
~
c
\tilde{p}_c
p~c。
我们可以像K-means一样执行多次精炼迭代。然而,我们尝试了多次迭代,并发现结果在单次精炼步骤之后并未得到进一步改善。
3.1.2 带有干扰类簇的软 k k k-均值原型网络
上述的软K-means方法隐式地假设每个未标注样本都属于情境中的 N N N 个类别之一。然而,如果不做这种假设,并让模型对来自其他类别的样本(我们称之为“干扰类别”)具有鲁棒性,这将更加通用。例如,如果我们想区分独轮车和滑板车的图片,并决定通过从网络下载图像来添加未标注集。那么,假设这些图像全都是独轮车或滑板车的图片是不现实的。即使进行了集中搜索,某些图像也可能来自类似类别,例如自行车。
由于软K-means将其软分配分布到所有类别上,干扰项可能是有害的,并会干扰精炼过程,因为原型也会被调整,以部分考虑这些干扰项。解决这个问题的一个简单方法是添加一个额外的聚类,其目的是捕捉干扰项,从而防止它们污染目标类别的聚类:
p
c
=
{
∑
i
h
(
x
i
)
z
i
,
c
∑
i
z
i
,
c
for
c
=
1
,
…
,
N
0
for
c
=
N
+
1
(5)
p_c = \begin{cases} \frac{\sum_i h(x_i) z_{i,c}}{\sum_i z_{i,c}} & \text{for } c = 1, \dots, N \\ 0 & \text{for } c = N + 1 \end{cases} \tag{5}
pc={∑izi,c∑ih(xi)zi,c0for c=1,…,Nfor c=N+1(5)
这里我们做一个简化假设,即干扰聚类的原型位于原点。我们还考虑引入长度尺度
r
c
r_c
rc 来表示聚类内部距离的变化,特别是对于干扰聚类:
z
~
j
,
c
=
exp
(
−
1
r
c
2
∥
x
~
j
−
p
c
∥
2
2
−
A
(
r
c
)
)
∑
c
′
exp
(
−
1
r
c
2
∥
x
~
j
−
p
c
′
∥
2
2
−
A
(
r
c
′
)
)
,
A
(
r
)
=
1
2
log
(
2
π
)
+
log
(
r
)
(6)
\tilde{z}_{j,c} = \frac{\exp \left( -\frac{1}{r_c^2} \| \tilde{x}_j - p_c \|_2^2 - A(r_c) \right)}{\sum_{c'} \exp \left( -\frac{1}{r_c^2} \| \tilde{x}_j - p_{c'} \|_2^2 - A(r_{c'}) \right)} , A(r) = \frac{1}{2} \log(2\pi) + \log(r) \tag{6}
z~j,c=∑c′exp(−rc21∥x~j−pc′∥22−A(rc′))exp(−rc21∥x~j−pc∥22−A(rc)),A(r)=21log(2π)+log(r)(6)
为了简化,我们在实验中将
r
1
,
…
,
r
N
r_1, \dots, r_N
r1,…,rN 设置为 1,只学习干扰聚类的长度尺度
r
N
+
1
r_{N+1}
rN+1。
3.1.3 基于软 k k k-均值与掩码的原型网络
使用单一聚类来建模干扰未标注样本可能过于简单。实际上,这与我们的假设(即每个聚类对应一个类别)不一致,因为干扰样本可能涉及不止一个自然物体类别。继续以独轮车和自行车为例,我们对未标注图像的网络搜索可能不仅包含自行车,还可能包括其他相关物体,如三轮车或汽车。我们的实验也反映了这一点,在实验中我们构造了情境生成过程,使其从多个类别中采样干扰样本。
为了解决这个问题,我们提出了一种改进的变体:不是用一个高方差的“备选”聚类来捕捉干扰项,而是将干扰项建模为不属于任何合法类别原型的某个区域内的样本。这是通过在未标注样本的贡献上加入软掩码机制来实现的。从高层次上看,我们希望距离原型较近的未标注样本被掩码的程度小于那些距离较远的样本。
更具体地,我们对软K-means精炼进行如下修改。我们首先计算样本
x
~
j
\tilde{x}_j
x~j 和原型
p
c
p_c
pc 之间的归一化距离
d
~
j
,
c
\tilde{d}_{j,c}
d~j,c:
d
~
j
,
c
=
d
j
,
c
1
/
M
∑
j
d
j
,
c
,
d
j
,
c
=
∥
h
(
x
~
j
)
−
p
c
∥
2
2
(7)
\tilde{d}_{j,c} = \frac{d_{j,c}}{1/M \sum_j d_{j,c}} , d_{j,c} = \| h(\tilde{x}_j) - p_c \|_2^2 \tag{7}
d~j,c=1/M∑jdj,cdj,c,dj,c=∥h(x~j)−pc∥22(7)
然后,通过将原型的归一化距离的各种统计量输入小型神经网络,预测每个原型的软阈值
β
c
\beta_c
βc 和斜率
γ
c
\gamma_c
γc:
[
β
c
,
γ
c
]
=
MLP
(
[
min
j
(
d
~
j
,
c
)
,
max
j
(
d
~
j
,
c
)
,
var
j
(
d
~
j
,
c
)
,
skew
j
(
d
~
j
,
c
)
,
kurt
j
(
d
~
j
,
c
)
]
)
(8)
[\beta_c, \gamma_c] = \text{MLP} \left([ \min_j (\tilde{d}_{j,c}), \max_j (\tilde{d}_{j,c}), \text{var}_j (\tilde{d}_{j,c}), \text{skew}_j (\tilde{d}_{j,c}), \text{kurt}_j (\tilde{d}_{j,c}) ] \right) \tag{8}
[βc,γc]=MLP([jmin(d~j,c),jmax(d~j,c),varj(d~j,c),skewj(d~j,c),kurtj(d~j,c)])(8)
这允许每个阈值使用关于聚类内变异量的信息来决定它应该如何激进地去除未标注样本。接下来,通过将归一化距离与阈值进行比较,计算每个样本对每个原型的贡献的软掩码
m
j
,
c
m_{j,c}
mj,c,如下所示:
p
~
c
=
∑
i
h
(
x
i
)
z
i
,
c
+
∑
j
h
(
x
~
j
)
z
~
j
,
c
m
j
,
c
∑
i
z
i
,
c
+
∑
j
z
~
j
,
c
m
j
,
c
,
m
j
,
c
=
σ
(
−
γ
c
(
d
~
j
,
c
−
β
c
)
)
(9)
\tilde{p}_c = \frac{\sum_i h(x_i) z_{i,c} + \sum_j h(\tilde{x}_j) \tilde{z}_{j,c} m_{j,c}}{\sum_i z_{i,c} + \sum_j \tilde{z}_{j,c} m_{j,c}} , m_{j,c} = \sigma \left( -\gamma_c (\tilde{d}_{j,c} - \beta_c) \right) \tag{9}
p~c=∑izi,c+∑jz~j,cmj,c∑ih(xi)zi,c+∑jh(x~j)z~j,cmj,c,mj,c=σ(−γc(d~j,c−βc))(9)
其中
σ
(
⋅
)
\sigma(\cdot)
σ(⋅) 是 sigmoid 函数。
在使用这一精炼过程进行训练时,模型现在可以通过公式 (8) 中的 MLP 学习是否完全忽略某些未标注样本或将其包含在内。使用软掩码使得这个过程完全可微分。最后,就像常规软K-means一样(无论是否有干扰聚类),尽管我们可以递归地执行多个精炼步骤,但我们发现单步精炼已经足够有效。
4 RELATED WORK
在这里,我们总结了与少样本学习、半监督学习和聚类相关的最重要的文献工作。
目前在少样本学习中表现最好的方法是使用元学习规定的情境训练框架。我们工作的研究范式属于度量学习方法。此前在少样本分类的度量学习方面的工作包括 Deep Siamese Networks、Matching Networks和 Prototypical Networks,而我们在本文中扩展了 Prototypical Networks 到半监督情境。其一般思路是学习一个嵌入函数,使得属于同一类别的样本被映射到空间中靠近的地方,而不同类别的样本则被映射到远离的地方。然后,支持集和查询集中的样本嵌入之间的距离被用作相似度度量来进行分类。最后,关于扩展少样本学习范式,Bachman 等在主动学习框架中使用 Matching Networks,其中模型可以在若干时间步骤中选择哪个未标注样本添加到支持集,然后进行查询集的分类。与我们的设置不同,他们的元学习代理可以从未标注集获取真实标签,并且没有使用干扰样本。
其他元学习方法在少样本学习中的应用包括学习如何使用支持集来更新学习模型,从而使其能在查询集上进行泛化。近期的工作涉及学习用于学习神经网络的权重初始化和/或更新步骤。另一种方法是训练一个通用神经架构,例如记忆增强型递归网络或时序卷积网络,以顺序地处理支持集并准确预测查询集样本的标签。这些方法在少样本学习中也具有竞争力,但我们选择在本文中扩展 Prototypical Networks,因为它具有简单性和高效性。
关于半监督学习的文献,尽管其内容相当广泛,但与我们工作最相关的类别是与自训练(self-training)相关的研究。在自训练中,首先在初始训练集上训练一个分类器。然后,使用该分类器对未标注的样本进行分类,并将分类器最自信的预测结果作为假定标签,将这些未标注样本添加到训练集中。这类似于我们对原型网络的软K-means扩展。实际上,由于软分配(公式 4)与常规原型网络对新输入的分类器输出(公式 2)相匹配,因此精炼过程可以看作是将来自未标注集的(软)自标签增强的新的支持集重新输入到原型网络。
我们的算法也与传导学习(transductive learning)相关,在传导学习中,基本分类器通过查看未标注样本进行精炼。实际上,可以在传导设置中使用我们的方法,其中未标注集与查询集相同;然而,为了避免我们的模型在元学习过程中记住未标注集的标签,我们将未标注集与查询集分开。
除了原始的K-means方法,与我们设置最相关的聚类算法工作涉及在存在异常值的情况下应用K-means。其目标是正确发现并忽略异常值,以防它们错误地将聚类位置偏移,从而形成错误的数据划分。这个目标在我们的设置中也非常重要,因为不忽略异常值(或干扰项)会错误地偏移原型,进而负面影响分类性能。
我们对半监督学习和聚类文献的贡献是超越经典的在单一数据集上进行训练和评估的设置,考虑到我们必须学习从训练集类别 C train C_{\text{train}} Ctrain 到新测试集类别 C test C_{\text{test}} Ctest 的迁移设置。
5 EXPERIMENTS
5.1 DATASETS
我们对半监督学习和聚类文献的贡献是超越经典的在单一数据集上进行训练和评估的设置,考虑到我们必须学习从训练集类别 我们在三个数据集上评估了模型的性能:两个基准少样本分类数据集和一个新的大规模数据集,我们希望这个数据集对于未来的少样本学习研究有所帮助。
Omniglot(Lake 等,2011)是一个包含50种字母表的1,623个手写字符的数据集。每个字符由20名人工标注者绘制。我们遵循Vinyals 等(2016)提出的少样本设置,将图像大小调整为28 × 28像素,并应用90°的旋转,最终得到总共6,492个类别。这些类别分为4,112个训练类别,688个验证类别和1,692个测试类别。
miniImageNet(Vinyals 等,2016)是对ILSVRC-12数据集(Russakovsky 等,2015)的修改版本,每个类别随机选择600张图像作为数据集的一部分。我们依赖Ravi & Larochelle(2017)使用的类别划分。这些划分使用64个类别用于训练,16个用于验证,20个用于测试。所有图像的大小为84 × 84像素。
tieredImageNet 是我们提出的用于少样本分类的数据集。与miniImageNet类似,它是ILSVRC-12的一个子集。然而,tieredImageNet代表了ILSVRC-12的一个更大子集(608个类别,而miniImageNet为100个类别)。类似于Omniglot,其中字符被分组为字母表,tieredImageNet将类别分组为对应于ImageNet层次结构中更高层节点的广义类别。总共有34个类别,每个类别包含10到30个类别。这些类别分为20个训练类别,6个验证类别和8个测试类别(数据集的详细信息可以在补充材料中找到)。这确保了所有训练类别与测试类别具有足够的区分度,不像miniImageNet和Vinyals 等(2016)提出的randImageNet等其他替代方法。例如,在Ravi & Larochelle(2017)划分的miniImageNet中,“管风琴”是一个训练类别,而“电吉他”是一个测试类别,尽管它们都是乐器。在tieredImageNet中,这种情况不会发生,因为“乐器”是一个高层类别,因此不会在训练和测试类别之间进行划分。这代表了一个更现实的少样本学习场景,因为通常我们不能假设测试类别与训练时看到的类别相似。此外,tieredImageNet的层次结构可能对能够利用类别之间层次关系的少样本学习方法有所帮助。我们将这些有趣的扩展留待未来工作中探讨。
表 1:Omniglot 1-shot 分类结果。在本表及以下表格中,“w/ D”表示“带干扰项”,其中未标注图像包含无关类别。
5.2 调整数据集以适应半监督学习
对于每个数据集,我们首先创建一个额外的划分,将每个类别的图像分为不重叠的标注集和未标注集。对于Omniglot和tieredImageNet,我们从每个类别中随机抽取10%的图像来形成标注集,其余的90%仅可用于情境中的未标注部分。对于miniImageNet,我们使用40%的数据作为标注集,其余60%作为未标注集,因为我们发现10%的数据集过小,无法取得合理的性能并避免过拟合。我们报告了在10个随机划分的标注集和未标注集上计算得到的平均分类得分,误差采用标准误差(标准差除以总划分数的平方根)计算。
我们要强调的是,由于这种标注/未标注划分,我们使用的标签信息严格少于之前在这些数据集上的工作中使用的信息。因此,我们不期望我们的结果与已发布的数字完全匹配,已发布的结果应被视为本文所定义的半监督模型性能的上限。
情境构建过程如下进行。对于给定的数据集,我们通过从训练类别集 C train C_{\text{train}} Ctrain 中随机均匀地抽取 N N N 个类别来创建训练情境。然后,我们从每个类别的标注集抽取 K K K 张图像来形成支持集,并从每个类别的未标注集抽取 M M M 张图像来形成未标注集。如果包含干扰项,我们还会从训练类别集中抽取 H H H 个其他类别,并从每个类别的未标注集抽取 M M M 张图像作为干扰项。这些干扰项图像与 N N N 个类别的未标注图像一起被加入到未标注集中(总共是 M × N + M × H M \times N + M \times H M×N+M×H 张未标注图像)。情境的查询部分由来自每个选定类别的标注集中的固定数量的图像组成。测试情境的创建方式类似,但从测试集 C test C_{\text{test}} Ctest 中抽取 N N N 个类别(可选地包括 H H H 个干扰类别)。在此实验中,我们使用了 H = N = 5 H = N = 5 H=N=5,即标注类别和干扰类别均为5个类别。我们在大多数情况下使用 M = 5 M = 5 M=5 作为训练集大小, M = 20 M = 20 M=20 作为测试集大小,从而衡量模型在更大未标注集大小上的泛化能力。数据集划分的详细信息,包括分配给训练/验证/测试集的具体类别,可以在附录A和B中找到。
在每个数据集上,我们将我们的三种半监督模型与两个基准模型进行比较。第一个基准,在我们的表格中称为“监督”,是一个普通的原型网络,使用每个数据集的标注集进行纯监督训练。第二个基准,称为“半监督推理”,使用该监督原型网络学习的嵌入函数,但在测试时使用软K-means精炼步骤对原型进行半监督精炼。与此相比,我们的半监督模型在训练时和测试时都执行精炼,因此学习一个不同的嵌入函数。我们在两种设置下评估每个模型:一种是所有未标注样本都属于感兴趣的类别,另一种是包含干扰项的更具挑战性的设置。模型超参数的详细信息可以在附录D和我们的在线仓库中找到。
5.3 RESULTS
Omniglot、miniImageNet 和 tieredImageNet 的结果分别给出在表 1、表 2 和表 5 中,而图 4 展示了我们的模型在 tieredImageNet(我们最大的数据集)上的表现,使用了不同的
M
M
M 值(每个类别中未标注集的项数)。附录 C 中提供了比较 ProtoNet 模型与这些数据集上的各种基准模型的额外结果,并对 Masked Soft k-Means 模型的表现进行了分析。
表 2:miniImageNet 1/5-shot 分类结果。
表 3:tieredImageNet 1/5-shot 分类结果。
在所有三个基准上,至少有一个我们提出的模型优于基准模型,证明了我们的半监督元学习过程的有效性。在无干扰项的设置中,所有三个提出的模型在几乎所有实验中都优于基准模型,但在不同数据集和样本数下,三种模型之间并没有明显的胜者。在训练和测试包含干扰项的情境下,Masked Soft k-Means 在所有三个数据集上表现出最稳健的性能,在每种情况下都达到了最佳结果(除一个案例外)。实际上,该模型的表现接近于基于无干扰项结果的上限。
从图 4 中,我们观察到当每个类别中的未标注集的项数从 0 增长到 25 时,测试准确率明显提高。这些模型在 M = 5 M = 5 M=5 时进行了训练,因此展示了其在泛化能力上的外推能力。这确认了通过元训练,模型学会了获取更好的表示,并通过半监督精炼得到改善。
图 4:在测试时使用不同数量的未标注样本时,模型在 tieredImageNet 上的表现。
6 CONCLUSION
在本工作中,我们提出了一种新颖的半监督少样本学习范式,其中为每个情境添加了一个未标注集。我们还将这一设置扩展到更现实的情况,其中未标注集包含与标注类别不同的新类别。为了应对当前少样本分类数据集在标注与未标注划分上过于小型且缺乏标签层次结构的问题,我们引入了一个新数据集——tieredImageNet。我们提出了几种新的原型网络扩展,并且它们在半监督设置下相较于我们的基准模型表现出一致的改进。作为未来工作,我们正在将快速权重融入我们的框架,使得在每个情境中,样本能够根据内容具有不同的嵌入表示。