小样本学习-原型网络
摘要
我们针对小样本分类问题提出了原型网络,在这一场景下要求分类器具有泛化能力。从而当出现训练集中没有的新类时只需少量新类就能有比较好的表现。原型网络学习一个度量空间,在其中可以通过计算到每个类的原型表示的距离来执行分类。与最近的小样本学习方法相比,它们反映了一种更简单的归纳偏差,这在这种有限的数据下是有益的,并取得了良好的结果。我们提供了一个分析,表明一些简单的设计可以比复杂的体系结构和元学习的方法产生更好的效果。我们进一步将原型网络扩展到0样本学习,并在CU-Birds数据集上取得了较好的效果。
1 序言
小样本学习是一项任务,其中分类器必须适应在训练中看不到的新类,只给出这些类的几个例子。一种简单的方法,比如在新数据上重新训练模型,会严重地过度拟合。虽然这个问题相当困难,但已经证明人类有能力执行一次分类,在只给出每个新类的一个例子的前提下,并且能够保持很高程度的准确性。最近的两种方法在小样本学习方面取得了重大进展。Vinyals等人提出了匹配网络,该网络使用注意机制对学习到的标记例子集(支持集)的嵌入来预测未标记点(查询集)的类。匹配网络可以解释为在嵌入空间内应用的加权最近邻分类器。值得注意的是,该模型在训练期间利用了被称为事件的采样小批,其中每个事件都被设计为通过子采样类和数据点来模拟小样本任务。事件的使用使训练问题更忠实于测试环境,从而提高了泛化性。Ravi和拉罗谢尔进一步提出了情景训练的想法,并提出了一种小样本学习的元学习方法。他们的方法包括训练一个LSTM,在给定一个事件时产生对分类器的更新,这样它将很好地推广到一个测试集。在这里,LSTM元学习者不是在多个情节中训练单个模型,而不是学习为每个情节训练一个自定义模型。
我们通过解决过拟合的关键问题来解决小样本学习的问题。由于数据非常有限,我们假设一个分类器应该有一个非常简单的归纳偏差。我们的方法,原型网络,是基于这样一种想法,即存在一个嵌入,其中点围绕着每个类的单个原型表示进行集群。为了做到这一点,我们使用神经网络学习输入到嵌入空间的非线性映射,并将一个类的原型作为其在嵌入空间中的支持集的平均值。然后,通过简单地找到最近的类原型来对嵌入式查询点执行分类。
我们采用相同的方法来处理零样本学习;在这里,每个类都带有元数据,给出了对类的高级描述,而不是少量带标记的示例。因此,我们学习将元数据嵌入到共享空间中,作为每个类的原型。与在少镜头场景中一样,通过为嵌入式查询点找到最近的类原型来执行分类。
在本文中,我们制定了针对少样本和零样本设置的原型网络。我们在一次性设置中绘制与匹配网络的联系,并分析模型中使用的底层距离函数。特别地,我们将原型网络与聚类联系起来,以便证明当使用布雷格曼散度计算距离时,使用类均值作为原型,例如平方欧几里得距离。我们发现,距离的选择是至关重要的,因为欧几里得距离大大优于更常用的余弦相似度。在几个基准测试任务上,我们表现都非常好。原型网络比最近的元学习算法更简单、更有效,这使它们成为少样本和零样本学习的一种很有吸引力的方法。
2 原型网络
2.1 符号表示
在小样本分类中,我们给出了一个N个标记例子的小支持集 S = { ( x 1 , y 1 ) , ( x 2 , y 2 ) . . . ( x n , y n ) } S = \{(x_1,y_1),(x_2,y_2)...(x_n,y_n)\} S={(x1,y1),(x2,y2)...(xn,yn)},其中每一个 x i ∈ R D x_i\in R^D xi∈RD为一个样本的 D D D维特征向量,并且数据对应的标签 y i ∈ { 1 , 2 , . . . K } , S k y_i\in \{1,2,...K\},S_k yi∈{1,2,...K},Sk是类别 k k k的有标注数据。
2.2 模型
原型网络计算了一个M维表示,
c
k
∈
R
M
c_k \in R^M
ck∈RM,我们也可将其称为原型。原型是将一类数据使用嵌入函数嵌入得到的。
f
ϕ
:
R
D
→
R
M
f_{\phi}:R^D \to R^M
fϕ:RD→RM其中
ϕ
\phi
ϕ是可以学习的参数。每个原型都是对应类的嵌入求平均得到的:
c
k
=
1
∣
S
k
∣
∑
(
x
i
,
y
i
)
∈
S
k
f
ϕ
(
x
i
)
\mathbf{c}_k=\frac{1}{\left|S_k\right|} \sum_{\left(\mathbf{x}_i, y_i\right) \in S_k} f_\phi\left(\mathbf{x}_i\right)
ck=∣Sk∣1(xi,yi)∈Sk∑fϕ(xi)
给定一个距离函数
d
:
R
M
×
R
M
∈
[
0
,
+
∞
)
d:R^M \times R^M \in [0,+\infty)
d:RM×RM∈[0,+∞)原型网络基于在嵌入空间中到原型的距离上的softmax值,为查询点x生成一个类别概率分布:
p
ϕ
(
y
=
k
∣
x
)
=
exp
(
−
d
(
f
ϕ
(
x
)
,
c
k
)
)
∑
k
′
exp
(
−
d
(
f
ϕ
(
x
)
,
c
k
′
)
)
p_\phi(y=k \mid \mathbf{x})=\frac{\exp \left(-d\left(f_{\boldsymbol{\phi}}(\mathbf{x}), \mathbf{c}_k\right)\right)}{\sum_{k^{\prime}} \exp \left(-d\left(f_\phi(\mathbf{x}), \mathbf{c}_{k^{\prime}}\right)\right)}
pϕ(y=k∣x)=∑k′exp(−d(fϕ(x),ck′))exp(−d(fϕ(x),ck))
损失函数为
J
(
ϕ
)
=
−
log
p
ϕ
(
y
=
k
∣
x
)
J(\phi)=-\log p_{\boldsymbol{\phi}}(y=k \mid \mathbf{x})
J(ϕ)=−logpϕ(y=k∣x),使用随机梯度下降法将损失最小化。训练集是通过从训练集中随机选择一个类的子集,然后在每个类中选择一个例子的子集作为支持集,剩下的子集作为查询集。伪代码如下。
2.3 原型网络作为混合密度估计
对于一类特殊的距离函数,即正则布雷格曼发散,原型网络算法等价于对具有指数族密度的支持集进行混合密度估计。正则布雷格曼散度
d
φ
d_{\varphi}
dφ的定义为:
d
φ
(
z
,
z
′
)
=
φ
(
z
)
−
φ
(
z
′
)
−
(
z
−
z
′
)
T
∇
φ
(
z
′
)
d_{\varphi}\left(\mathbf{z}, \mathbf{z}^{\prime}\right)=\varphi(\mathbf{z})-\varphi\left(\mathbf{z}^{\prime}\right)-\left(\mathbf{z}-\mathbf{z}^{\prime}\right)^T \nabla \varphi\left(\mathbf{z}^{\prime}\right)
dφ(z,z′)=φ(z)−φ(z′)−(z−z′)T∇φ(z′),
其中
φ
\varphi
φ是勒让德型的可微严格凸函数。布雷格曼发散的例子包括平方欧几里得距离
∣
∣
z
−
z
,
∣
∣
2
||z - z^{,}||^2
∣∣z−z,∣∣2和马氏距离。原型计算可以从支持集上的硬聚类来看,每个类有一个聚类,每个支持点分配给其相应的类聚类。对于布雷格曼发散的表明,到其指定点达到最小的聚类代表距离是聚类平均值。因此,当使用布雷格曼散度时,原型计算得到了给定支持集标签的最优聚类原型。此外,任何参数为
θ
\theta
θ和累积函数为
φ
\varphi
φ的正则指数族分布
p
φ
(
z
∣
θ
)
p_{\varphi}(z|\theta)
pφ(z∣θ)都可以用唯一确定的正则布雷格曼散度来表示。
p
ψ
(
z
∣
θ
)
=
exp
{
z
T
θ
−
ψ
(
θ
)
−
g
ψ
(
z
)
}
=
exp
{
−
d
φ
(
z
,
μ
(
θ
)
)
−
g
φ
(
z
)
}
p_\psi(\mathbf{z} \mid \boldsymbol{\theta})=\exp \left\{\mathbf{z}^T \boldsymbol{\theta}-\psi(\boldsymbol{\theta})-g_\psi(\mathbf{z})\right\}=\exp \left\{-d_{\varphi}(\mathbf{z}, \boldsymbol{\mu}(\boldsymbol{\theta}))-g_{\varphi}(\mathbf{z})\right\}
pψ(z∣θ)=exp{zTθ−ψ(θ)−gψ(z)}=exp{−dφ(z,μ(θ))−gφ(z)}
现在考虑一个带参数的正则指数族混合模型
Γ
=
{
θ
k
,
π
k
}
k
=
1
K
\boldsymbol{\Gamma}=\left\{\boldsymbol{\theta}_k, \pi_k\right\}_{k=1}^K
Γ={θk,πk}k=1K :
p
(
z
∣
Γ
)
=
∑
k
=
1
K
π
k
p
ψ
(
z
∣
θ
k
)
=
∑
k
=
1
K
π
k
exp
(
−
d
φ
(
z
,
μ
(
θ
k
)
)
−
g
φ
(
z
)
)
p(\mathbf{z} \mid \boldsymbol{\Gamma})=\sum_{k=1}^K \pi_k p_\psi\left(\mathbf{z} \mid \boldsymbol{\theta}_k\right)=\sum_{k=1}^K \pi_k \exp \left(-d_{\varphi}\left(\mathbf{z}, \boldsymbol{\mu}\left(\boldsymbol{\theta}_k\right)\right)-g_{\varphi}(\mathbf{z})\right)
p(z∣Γ)=k=1∑Kπkpψ(z∣θk)=k=1∑Kπkexp(−dφ(z,μ(θk))−gφ(z))
Given
Γ
\Gamma
Γ,对一个未标记的点
z
\mathbf{z}
z的类赋值
y
y
y的推断为:
p
(
y
=
k
∣
z
)
=
π
k
exp
(
−
d
φ
(
z
,
μ
(
θ
k
)
)
)
∑
k
′
π
k
′
exp
(
−
d
φ
(
z
,
μ
(
θ
k
)
)
)
p(y=k \mid \mathbf{z})=\frac{\pi_k \exp \left(-d_{\varphi}\left(\mathbf{z}, \boldsymbol{\mu}\left(\boldsymbol{\theta}_k\right)\right)\right)}{\sum_{k^{\prime}} \pi_{k^{\prime}} \exp \left(-d_{\varphi}\left(\mathbf{z}, \boldsymbol{\mu}\left(\boldsymbol{\theta}_k\right)\right)\right)}
p(y=k∣z)=∑k′πk′exp(−dφ(z,μ(θk)))πkexp(−dφ(z,μ(θk)))
对于每个类有一个聚类的等加权混合模型,聚类分配推理等价于查询类预测
f
ϕ
(
x
)
=
z
f_\phi(\mathbf{x})=\mathbf{z}
fϕ(x)=z 和
c
k
=
μ
(
θ
k
)
\mathbf{c}_k=\boldsymbol{\mu}\left(\boldsymbol{\theta}_k\right)
ck=μ(θk). 在这种情况下,原型网络可以有效地利用
d
φ
d_{\varphi}
dφ确定的指数族分布进行混合密度估计。因此,距离的选择指定了关于嵌入空间中的类条件数据分布的建模假设。
2.4重新解释为一个线性模型
一个简单的分析有助于深入了解所学习的分类器的本质。当我们使用欧几里得距离
d
φ
(
z
,
z
′
)
=
∣
∣
z
−
z
,
∣
∣
2
d_{\varphi}\left(\mathbf{z}, \mathbf{z}^{\prime}\right) = ||z - z^,||^2
dφ(z,z′)=∣∣z−z,∣∣2时,那么方程中的模型等价于一个具有特定参数化的线性模型。要看到这一点,请在指数中展开该项:
−
∥
f
ϕ
(
x
)
−
c
k
∥
2
=
−
f
ϕ
(
x
)
⊤
f
ϕ
(
x
)
+
2
c
k
⊤
f
ϕ
(
x
)
−
c
k
⊤
c
k
-\left\|f_\phi(\mathbf{x})-\mathbf{c}_k\right\|^2=-f_\phi(\mathbf{x})^{\top} f_\phi(\mathbf{x})+2 \mathbf{c}_k^{\top} f_\phi(\mathbf{x})-\mathbf{c}_k^{\top} \mathbf{c}_k
−∥fϕ(x)−ck∥2=−fϕ(x)⊤fϕ(x)+2ck⊤fϕ(x)−ck⊤ck
式(7)中的第一项对于类
k
k
k, s是常数所以它不影响软最大概率。我们可以将剩下的项写成线性模型,如下所示::
2
c
k
⊤
f
ϕ
(
x
)
−
c
k
⊤
c
k
=
w
k
⊤
f
ϕ
(
x
)
+
b
k
, where
w
k
=
2
c
k
and
b
k
=
−
c
k
⊤
c
k
2 \mathbf{c}_k^{\top} f_\phi(\mathbf{x})-\mathbf{c}_k^{\top} \mathbf{c}_k=\mathbf{w}_k^{\top} f_\phi(\mathbf{x})+b_k \text {, where } \mathbf{w}_k=2 \mathbf{c}_k \text { and } b_k=-\mathbf{c}_k^{\top} \mathbf{c}_k
2ck⊤fϕ(x)−ck⊤ck=wk⊤fϕ(x)+bk, where wk=2ck and bk=−ck⊤ck
在这项工作中,我们主要关注欧几里得距离的平方(对应于球形高斯密度)。我们的结果表明,欧几里得距离是一个有效的选择,尽管等价于线性模型。我们假设这是因为所有所需的非线性都可以在嵌入函数中学习。事实上,这是现代神经网络分类系统目前使用的方法.
2.5 与匹配网络的比较
原型网络不同于匹配网络在小样本情况下的等价只有一次的场景。匹配网络给出一个加权最近邻分类器支持集,而原型网络在欧式距离平方时产生线性分类器使用。在一次学习的情况下, c k = x k c_k = x_k ck=xk,因为每个类只有一个支撑点,并且匹配网络和原型网络是等价的。一个自然的问题是,每个类使用多个原型而不是一个原型是否有意义。如果每个类的原型数量是固定的并且大于1,那么这将需要一个分区方案来进一步聚类类内的支持点。这已经在Mensink等人[21]和Rippel等人[27]中提出;然而,这两种方法都需要一个单独的划分阶段,从权重更新解耦,而我们的方法易于用普通的梯度下降方法学习。Vinyals等人。[32]提出了许多扩展,包括解耦支持点和查询点的嵌入函数,以及使用第二级、全条件嵌入(FCE),它在每个事件中考虑到特定的点。这些同样可以被合并到原型网络中,但是它们增加了可学习参数的数量,并且FCE使用双向LSTM对支持集进行任意排序。相反,我们展示了使用简单的设计选择实现相同水平的性能是可能的,我们将概述这一点.
2.6设计选择
距离度量维尼亚尔斯等人,[32]和Ravi和拉罗谢尔[24]应用了使用余弦距离的匹配网络。然而,对于原型网络和匹配网络,任何距离都是允许的,我们发现使用平方欧氏距离可以大大提高两者的结果。对于典型网络,我们推测这主要是由于余弦距离不是布雷格曼散度,因此与第2.3节中讨论的混合密度估计的等价性不成立。
集合的选择在Vinyals等人[32]和Ravi和Larochelle[24]中使用的一种简单的构建集合的方法是选择每个类的Nc类和Ns支持点,以匹配测试时的预期情况。也就是说,如果我们期望在测试时进行5-way分类和1-shot学习,则训练集可以由Nc = 5, NS = 1组成。然而,我们已经发现,使用比测试时使用的更高的Nc或“方式”进行训练是非常有益的。在我们的实验中,我们在一个保留的验证集上调整训练Nc。另一个需要考虑的问题是在训练和测试时是否匹配NS(即“样本”)。对于原型网络,我们发现通常最好使用相同的“样本”数字进行训练和测试。
2.7零样本学习
零样本学习与少样本学习的不同之处在于,我们没有给出训练点的支持集,而是为每个类提供了一个类元数据向量 v k v_k vk。这些可以提前确定,或者可以从例如原始文本[8]中学习。修改原型网络以处理零镜头情况很简单:我们简单地定义 c k = g ϑ ( v k ) c_k = g_ϑ(v_k) ck=gϑ(vk)为元数据向量的单独嵌入。图1显示了原型网络的零样本过程,因为它与少样本过程相关。由于元数据向量和查询点来自不同的输入域,我们发现将原型嵌入g固定为单位长度是有帮助的,但我们没有限制查询嵌入f。
3实验
对于少样本学习,我们使用Ravi和Larochelle[24]提出的分割,在Omniglot[18]和ILSVRC-2012[28]的miniImageNet版本上进行实验。我们在加州理工大学UCSD鸟类数据集(CUB-200 2011)[34]的2011版本上进行零样本实验。
3.1Omniglot小样本分类
Omniglot[18]是一个从50个字母中收集的1623个手写字符的数据集。每个字符都有20个相关的例子,每个例子都是由不同的人类主体绘制的。我们遵循Vinyals等人[32]的程序,将灰度图像调整为28 × 28,并以90度的倍数旋转增加字符类。我们使用1200个字符加上旋转进行训练(总共4800个类),其余的类(包括旋转)用于测试。我们的嵌入体系结构反映了Vinyals等人使用的[32],由四个卷积块组成。每个块包括一个64滤波器3 × 3卷积,批规范化层[12],一个ReLU非线性和一个2 × 2最大池化层。当应用于28 × 28的Omniglot图像时,这种架构会产生64维的输出空间。我们使用相同的编码器嵌入支持点和查询点。我们所有的模型都是通过SGD和Adam[13]训练的。我们使用10−3的初始学习率,并每2000集将学习率降低一半。除批规格化外,未使用规格化。我们在1样本和5样本场景中使用欧几里得距离训练原型网络,训练集包含60个类和每个类5个查询点。我们发现,将训练镜头与测试镜头相匹配是有利的,并且每个训练集使用更多的类(更高的“方式”)而不是更少。我们与各种基线进行比较,包括神经统计学家[7],元学习器LSTM [24], MAML[9],以及匹配网络[32]的微调和非微调版本。我们计算了从测试集中平均超过1000个随机生成集的模型的分类精度。结果如表1所示,据我们所知,在这个数据集上与最先进的技术相竞争。图2显示了prototype Networks学习的嵌入的示例t-SNE可视化[20]。我们将来自相同字母的测试字符子集可视化,以便获得更好的洞察力,尽管实际测试集中的类可能来自不同的字母。即使可视化的字符彼此之间是微小的变化,网络也能够将手绘字符紧密地聚集在类原型周围。
3.2miniImageNet小样本分类
miniImageNet数据集最初由Vinyals等人提出。[32],是从更大的ILSVRC-12数据集[28]派生出来的。Vinyals et al.[32]使用的分割由60000张大小为84 × 84的彩色图像组成,分为100个类,每个类600个示例。在我们的实验中,我们使用了Ravi和Larochelle[24]引入的分割,以便直接与最先进的少镜头学习算法进行比较。他们的分组使用了不同的100个类,分为64个训练类、16个验证类和20个测试类。我们遵循他们的程序,在64个训练类上进行训练,并仅使用16个验证类来检测泛化性能。
我们使用了与Omniglot实验中相同的四块嵌入架构,尽管在这里,由于图像大小的增加,它产生了一个1600维的输出空间。我们还使用与Omniglot实验相同的学习率计划,并进行训练,直到验证损失停止改善。我们使用30类进行1样本分类,使用20类进行5样本分类。我们匹配训练镜头到测试镜头,每个类每轮包含15个查询点。我们与Ravi和拉罗谢尔[24]报告的baseline进行了比较,其中在分类网络学习的特征之上包括一个简单的最近邻方法。其他基线是匹配网络(两种普通网络和FCE)和元学习者LSTM的两种非微调变体。在非微调设置中,因为Vinyals等人[32]提出的微调过程没有被完全描述。从表2中可以看出,典型网络在5样本精度上达到了最先进的水平。
我们进行了进一步的分析,以确定距离度量和每集训练课的数量对原型网络和匹配网络性能的影响。为了使这些方法具有可比性,我们使用了自己的匹配网络实现,它利用了与原型网络相同的嵌入架构。在图3中,我们比较了1次和5次场景下的余弦与欧几里得距离以及5次和20次训练集,每集每个类有15个查询点。我们注意到20-way比5-way获得了更高的精度,并推测20-way分类难度的增加有助于网络更好地泛化,因为它迫使模型在嵌入空间中做出更细粒度的决策。此外,使用欧几里得距离比余弦距离大大提高了性能。这种效应在原型网络中更为明显,在原型网络中,计算类原型作为嵌入支撑点的平均值更自然地适合于欧几里得距离,因为余弦距离不是布雷格曼散度。
3.3 CUB零样本分类
为了评估我们的零镜头学习方法的适用性,我们还在加州理工大学ucsd Birds (CUB) 200-2011数据集[34]上进行了实验。CUB数据集包含200种鸟类的11788张图像。我们严格按照Reed et al.[25]的流程来准备数据。我们使用他们的分割将班级分为100个训练,50个验证和50个测试。对于图像,我们使用通过将GoogLeNet[31]应用于原始和水平翻转图像2的中间、左上、右上、左下和右下作物来提取的1024维特征。在测试时,我们只使用原始图像的中间部分。对于类元数据,我们使用CUB数据集提供的312维连续属性向量。这些属性编码了鸟类的各种特征,如颜色、形状和羽毛图案。
我们在1024维的图像特征和312维的属性向量上学习了一个简单的线性映射,以生成1024维的输出空间。对于这个数据集,我们发现将类原型(嵌入属性向量)规范化为单位长度很有帮助,因为属性向量来自与图像不同的领域。训练集由50个类和每个类10个查询图像构成。在固定学习速率为10−4、权值衰减为10−5的情况下,采用Adam的SGD算法对嵌入进行优化。在验证损失时的早期停止用于确定在训练加验证集上重新训练的最佳epoch数。
表3显示,与使用属性作为类元数据的方法相比,我们获得了最先进的结果。我们将我们的方法与各种零镜头学习方法进行比较,包括其他嵌入方法,如ALE [1], SJE[2]和DS-SJE/DA-SJE[25]。我们还比较了最近的聚类方法[19],该方法在由微调获得的学习特征空间上训练SVM AlexNet[16]。[6]的综合分类器方法是一种将类元数据空间与可视化模型空间对齐的流形学习技术,Zhang和Saligrama[36]的方法是一种在VGG-19特征[30]上训练的结构化预测方法。由于Zhang和Saligrama[36]是随机方法,我们将他们报告的误差条纳入表3。我们的原型网络优于综合分类器,并且在Zhang和Saligrama[36]的误差范围内,同时是一种比两者都简单得多的方法
我们还使用更强的类元数据进行了一组额外的零样本实验。我们使用[25]的预训练的Char CNN-RNN模型为每个CUB-200类提取1024维元数据向量,然后使用上述相同的过程训练零镜头原型网络,除了我们使用了通过验证精度选择的512维输出嵌入。我们获得了58.3%的测试精度,相比之下,DS-SJE[25]使用Char CNN-RNN模型获得了54.0%的精度。此外,我们的结果超过了DS-SJE获得的56.8%的准确率,甚至更强的Word CNN-RNN类元数据表示。总的来说,这些零镜头分类结果表明,即使数据点(图像)来自相对于类(属性)的不同领域,我们的方法也足够通用。
4相关工作
关于度量学习的文献很多[17,5];我们在这里总结了与我们提出的方法最相关的工作。邻域成分分析(NCA)[10]学习马氏距离以最大化k -近邻(KNN)在转换空间中的遗漏精度。Salakhutdinov和Hinton[29]利用神经网络对NCA进行了扩展。大边界最近邻(LMNN)分类[33]也试图优化KNN的准确性,但使用铰链损失来鼓励一个点的局部邻域包含具有相同标签的其他点。DNet-KNN[23]是另一种基于边缘的方法,它改进了LMNN,利用神经网络来执行嵌入,而不是简单的线性变换。在这些方法中,我们的方法最类似于NCA[29]的非线性扩展,因为我们使用神经网络来执行嵌入,并且我们基于转换空间中的欧几里得距离优化软最大值,而不是边际损失。我们的方法和非线性NCA之间的一个关键区别是,我们直接在类上形成一个softmax,而不是从每个类的原型表示的距离计算出来的单个点。这使得每个类都有一个独立于数据点数量的简洁表示,并且避免了存储整个支持集来进行预测的需要。
我们的方法也类似于最近的类均值方法[21],其中每个类由其示例的均值表示。这种方法是为了在不重新训练的情况下快速将新类合并到分类器中,然而它依赖于线性嵌入,并被设计用于处理新类带有大量示例的情况。相比之下,我们的方法利用神经网络来非线性嵌入点,并将其与情景训练相结合,以处理少数射击场景。Mensink等人试图扩展他们的方法来执行非线性分类,但他们是通过允许类拥有多个原型来实现的。他们通过在输入空间上使用k-means在预处理步骤中找到这些原型,然后执行其线性嵌入的多模态变体。另一方面,原型网络以端到端方式学习非线性嵌入,不需要这样的预处理,产生一个非线性分类器,每个类仍然只需要一个原型。此外,我们的方法自然地推广到其他距离函数,特别是布雷格曼散度
Wen et al.[35]提出的人脸识别的中心损失与我们的相似,但有两个主要区别。首先,他们学习每个类的中心作为模型的参数,而我们计算原型作为每个集中的标记示例的函数。其次,他们将中心损失与软最大损失结合起来,以防止表示崩溃为零,而我们从原型中构造了软最大损失,自然地防止了这种崩溃。此外,我们的方法是为少数镜头场景而设计的,而不是人脸识别。
相关的少样本学习方法是Ravi和Larochelle[24]中提出的元学习方法。这里的关键见解是LSTM动态和梯度下降可以有效地以相同的方式编写。然后,LSTM可以被训练为自己从给定的事件中训练模型,其性能目标是在查询点上很好地泛化。MAML[9]是另一种用于少样本学习的元学习方法。它试图学习一种表示方法,它可以很容易地适应新数据,并且只需要很少的梯度下降。匹配网络和原型网络也可以被视为元学习的形式,从某种意义上说,它们从新的训练集动态地生成简单的分类器;然而,它们所依赖的核心嵌入是经过训练后固定的。匹配网络的FCE扩展涉及到依赖于支持集的二次嵌入。然而,在很少样本的场景中,数据量是如此之小,以至于简单的归纳偏差似乎工作得很好,而不需要为每一集学习自定义嵌入。
原型网络也与生成建模文献中的神经统计学家[7]相关,它扩展了变分自编码器[14,26],以学习数据集的生成模型,而不是单个点。神经统计学家的一个组成部分是“统计网络”,它将一组数据点总结为一个统计向量。它通过对数据集中的每个点进行编码,获取样本均值,并应用后处理网络来获得统计向量的近似后验。Edwards和Storkey[7]在Omniglot数据集上测试他们的一次性分类模型,他们将每个字符视为一个单独的数据集,并基于其在统计向量上的近似后验与测试点推断的后验具有最小的kl -散度的类进行预测。像神经统计学家一样,我们也为每个类生成一个汇总统计。然而,我们的模型是一个判别模型,适合我们的少样本分类判别任务。
在zero-shot学习方面,Prototypical Networks中嵌入式元数据的使用类似于[3]的方法,两者都预测线性分类器的权重。[25]的DS-SJE和DA-SJE方法还学习了图像和类元数据的深度多模态嵌入函数。与我们不同的是,他们使用经验风险损失来学习。[3]和[25]都没有使用情景训练,这使我们能够帮助加快训练和正则化模型。
5结论
我们提出了一种简单的few-shot学习的方法称作原型网络,其基本思想是,在一个由神经网络学习的表示空间中用样例的平均值来表示每一类。我们通过使用episode训练使得神经网络在few-shot学习中表现的特别好。这种方法比元学习简单并且更有效,即便没有匹配网络进行复杂的拓展也能产生最新的结果(尽管这些方法也可以应用于原型网络)。我们展示了如何通过仔细考虑所选择的距离度量,并通过修改Episode学习过程来大大提高性能。我们进一步展示了如何将原型网络推广到zero-shot setting,并且在CUB-200数据集上实现了最新的结果。未来工作的一个自然方向是利用Bregman发散,而不是平方欧氏距离,对应于超越球面高斯的类条件分布。我们对此进行了初步的探索,包括为一个类学习每个维度的方差。这并没有导致任何经验收益,这表明嵌入网络本身具有足够的灵活性,而不需要每个类的附加拟合参数。总的来说,原型网络的简单性和有效性使其成为一种有前途的few-shot学习方法。