《Prototypical Networks for Few-shot Learning 》论文翻译

                  Prototypical Networks for Few-shot Learning

 

Abstract

我们为小样本分类问题提出了原型网络,其中分类器能够很好的泛化到其他没有在训练集中出现的新类别,对于每一种新出现的类别,我们只给出很少的样本。原型网络学习一个度量空间,在该空间中,可以通过计算到每个类的原型表示的距离来执行分类。与最近的小样本学习方法相比,它们反映了一种更简单的归纳偏差,有利于在这种有限的数据范围内使用,并取得优异的效果。我们提供了一个分析,表明一些简单的设计决策比最近涉及复杂体系结构选择和元学习的方法可以产生实质性的改进。我们进一步将原型网络扩展到零样本学习学习,并在CU Birds数据集上获得最新结果。

1.Introduction

小样本分类是一个任务,其中一个分类器必须能够进行调整以适应训练过程没有见过的类,并且这些新出现的类只给出很少的样本。一种简单的方法,比如在新数据上重新训练模型,极有可能会出现过拟合问题。虽然这个问题很困难,但是已经表明人类有能力去执行甚至一个样本的分类,每个新类只给了一个例子,并且具有很高的准确性。

最近有两种方法在小样本学习上取得了重大进步,Vinyals等人提出匹配网络,它应用注意机制的学习嵌入处理标记好的样本(支持集)来预测未标记的点(查询集)的类别。匹配网络可以被解释为应用于嵌入空间中的加权最邻近分类器。值得注意的是,该模型在训练的时候应用成为片段的抽样小批量,其中每个片段都被设计通过子采样类和数据点去模仿小样本任务。Episodes的使用使得训练问题更加适应测试环境,从而提高泛化能力。Ravi 和 Larochelle进一步利用了Episodes训练的想法并提出了一种元学习的方法用于小样本学习。他们的方法包括在一个给定Eisode的情况下训练一个LSTM去去更新一个分类器,这样就会很好的推广到测试集。在这里,LSTM元学习不是在多个Episodes上训练单个模型,而是学习为每一个Episode训练一个合适的模型。

我们通过解决关键问题过拟合来解决小样本学习的问题。由于数据非常的有限,我们假设一个分类器应该应该有非常简单的inductive bias(不知道应该怎么翻译).我们的方法,原型网络,是基于存在embedding的想法,这些点围绕每个类的单个原型表示进行聚集。为了做到这点,我们学习了一个非线性的映射,利用神经网络将输入映射到一个embedding space,并且在embedding space中将支持集中平均值作为类的原型。然后通过简单的查找最近的类原型对嵌入的查询点进行分类。我们利用同样的方法处理零样本学习,在这里,每个类都带有给出类高度描述的元数据,而不是少量的标记数据。因此我们学习将元数据嵌入到共享空间中,作为每个类的原型。

分类是通过为一个embedded query point(嵌入式查询点)找到最近的类原型来执行的在小样本场景中。

在这篇论文中,我们给few-shot和 zero-shot settings制定了原型网络。我们还与one-shot setting中的匹配网络建立联系,并且分析下层距离函数在模型中的应用。特别的,我们将原型网络与聚类联系起来,,以证明当使用Bregman散度(如平方欧氏距离)计算距离时,类平均值作为原型的使用是正确的。我们从经验上发现,距离的选择是至关重要的,因为欧几里德距离大大优于更常用的余弦相似度。在几项核心任务中,我们实现了最先进的性能。原型网络比最近的元学习算法更简单、更有效,使其成为一种小样本学习和零样本学习方法。

2 Prototypical Networks

2.1 符号说明

在小样本分类中,我们有一个小的支持集,有N个标记的样本S = {(x1, y1), . . . ,(xN , yN )}

其中:,是每个样例的D维特征向量,是对应的标记,

是类别k的已标记的样例

2.2模型

原型网络通过带有可学习参数的嵌入函数,为每一个类计算一个M为的原型,每个原型是属于其类的嵌入式支持点的平均向量:

 

给一个距离函数,原型网络为一个查询点x生成一个类上的分布,这个查询点x基于一个softmax在嵌入空间中原型的距离上

 

通过SGD来最小化所属真实类别x的负log概率来进行学习过程。训练集是通过从训练集中随机选择一个类的子集,然后在每个类中选择一个示例子集作为支持集,其余的子集作为查询点而形成的。在训练过程中计算损失的伪代码如下:

 

Alorithm 1:

N:训练集中样例的数量

K:训练集中类的数量

NC:每个Episode中类别的数量

NS:每个类中支持样例的数量

NQ:每个类中查询样例的数量

RandomSample(S,N):denotes a set of N elements chosen uniformly at  random from set S, without replacement.

输入:训练集D,Dk代表类别k的子集

输出:随机产生的训练Episode的损失J

计算过程:为Episode选择类别-》选择支持集-》选择训练集-》计算支持集的原型-》初始化损失-》更新损失

2.3作为混合密度估计的原型网络

对于一类特殊的距离函数,如我们所知道的正则Bregman散度,原型函数相当于对具有指数族密度的支持集进行混合密度估计。正则Bregman散度定义如下:

φ是可微的,Legendre型严格凸函数。Bregman散度的例子包括平方欧式距离和Mahalanobis distance

原型计算可以通过支持集上的hard clustering来查看,每个类一个簇并且每个支持点都分配给它相应的群集。Bregman散度的结果表明,达到最小距离的聚类代表是聚类平均数。因此,当使用Bregman散度时,方程(1)中的原型计算在给定支持集标签的情况下产生最佳簇代表。

此外,任何具有参数和积累量的正则指数分布都可以用唯一正确的正则Bregman散度表示:

 

考虑带参数的正则指数分布混合模型

 

对未标记点z的分布赋值y的推断变为:

 

对于每一个类有一个集群的等权混合模型,集群分配推理等价于用

进行查询类预测。在这种情况下,原型网络有效地执行了混合密度估计,其指数分布用dϕ确定。因此,距离的选择指定了关于嵌入空间中类条件数据分布的建模假设。

2.4 重新解释为线性模型

简单的分析有助于深入了解学习的分类器的性质。当我们使用欧氏距离

时,方程(2)中的模型等价于等价于具有特定参数化的线性模型。 要查看此项,展开指数中的项:

在方程(7)中的第一项是关于第k类的一个常数,所以它不影响softmax的概率。我们可以把剩下的项写成一个模型,如下:

我们主要关注平方欧式距离(对应于高斯距离)。结果表明,尽管欧氏距离与线性模型等价但是它是一种有效的选择。我们假设这是因为所有需要的非线性都可以在嵌入函数中学习。实际上,这是现代神经网络分类系统目前使用的方法。

2.5 与匹配网络的比较

原型网络和匹配网络在 few-shot场景下是不同的,但是在one-shot场景下是等价的。匹配网络在给定支持集的情况下产生加权最近邻分类器,而原型网络在使用平方欧氏距离时产生线性分类器。在一次学习的情况下,ck=xk,因为每个类只有一个支持点,匹配网络和原型网络变得等价。

一个自然的问题是每个类使用多个原型而不是一个是否有意义。如果每个类的原型固定而且大于1,那么这将需要一个分区方案去进一步集群每一个类中的支持点。Mensinket al.  and Rippel et al 等人已经提出过,但是这两种方法都需要一个独立的分区阶段,该阶段与权重更新分离,而我们的方法使用普通的梯度下降方法很容易学习。

Vinyals等人提出了一些扩展,包括分离支持点和查询点的嵌入函数,使用第二级的完全条件嵌入(FCE),它考虑到每一Episode的特定点。它们同样可以被纳入到原型网络中,但是它们增加了可学习参数的数量,并且FCE使用双向LSTM对支持集施加任意顺序。相反,我们展示了使用简单的设计选择来实现相同级别的性能是可能的,下面我们将概述这一点。

2.6 设计选择

距离度量:对于原型网络和匹配网络,任何距离都是允许的,并且我们发现使用平方欧几里德距离可以极大地改善这两者的结果。我们推测这主要是由于余弦距离不是Bregman散度,因此第2.3节中讨论的混合密度估计的等价性不成立。

Episode composition :一种简单的构造Episodes的方法,是每个类选择Nc类和Ns支持点,以便在测试时匹配预期的情况。也就是说,如果我们期望在测试时执行一个5个类一个样本的学习,训练episodes可以用Nc=5,Ns=1组成。然而,我们发现,使用比测试时更高的Nc或“way”进行训练是非常有益的。在我们的实验中,我们调整训练Nc在一个保持有消极上。另一个考虑因素是在训练和测试时间是匹配Ns还是‘shot’。对于原型网络,我们发现最好的是使用相同的‘shot’进行训练和测试。

2.7 Zero-Shot Learning

Zero-shot与few-shot的不同之处在于,在没有给出训练点的支持集的情况下,我们给出了每个类的类元数据vk。这些可以实现确定,也可以从原始文本中学习。修改原型网络去处理zero-shot问题是很简单那的,我们定义为元数据向量的单独嵌入。图一显示了zero-shot,few-shot以及原型网络。由于元数据向量和查询点来自不同的输入域,我们发现将原型嵌入g固定为单位长度是有帮助的,但是我们不限制查询嵌入f。

3、Experiments

对于few-shot学习,我们在 Omniglot [16] and the miniImageNet version of ILSVRC-2012 上进行实验,对于zero-shot学习我们在2011 version of the Caltech UCSD bird dataset (CUB-200 2011)上进行实验。

3.1 Omniglot Few-shot Classifification

Omniglot是一个从50个字母表中收集的1623个手写字符的数据集。每一个字符有20个样例,每一个样例都是由不同的人绘制的。我们将灰度图像调整为28×28并且通过旋转90度的倍数来增加character classes。我们用1200个字符加上其旋转作为训练,用剩下的类包括旋转的用作测试。我们的嵌入架构由四个卷积块组成。每一块包括64个3×3的卷积滤波器,批处理规范化层,ReLu非线性层和2×2的最大池化层,当应用于28×28的Omniglot图像时,这种结构产生64维的输出空间。我们对嵌入支持点和查询点使用相同的编码器。我们的模型都用Adam进行训练。我们初始化学习速率为10e-3,并且每2000Episode减少一半的学习率。除了批处理规范化,我们没有使用正则化,我们使用欧几里德距离在1-shot和5-shot场景中训练原型网络,训练集包含60个类和每个类5个查询点。我们发现将训练镜头和测试镜头进行匹配是有利的,并且每个训练片段使用更多的类(higher way)而不是使用更少。我们比较了各种基线(baselines),包括 neural statistician和匹配网络的微调和非微调的版本。我们计算了我们的模型的分类精度,平均超过1000个随机从测试集产生的Episode。结果如表1所示,据我们所知,它们代表了该数据集的最新技术。

3.2 miniImageNet Few-shot Classifification

miniImageNet数据集一开始由Vinyals等人提出,是从较大的ILSVRC-12数据集分离出来的。miniImageNet包括60000张大小为84×84的彩色图片,图片被分成100个类并且每个类有600个样例。在我们的实验中,我们使用由Ravi和Larochelle引入的分离,以便直接与小样本学习中最先进的算法进行比较。他们的分组使用了一组不同的100个类,分为64个训练类、16个验证类和20个测试类。我们遵循他们的程序,在64个训练类上进行训练,并使用16个验证类来检查泛化能力。

我们使用与Omniglot实验相同的四块嵌入架构,但由于图像尺寸的增加,这里的输出空间为1600维。我们还使用与Omniglot实验相同的学习率计划,并进行训练,直到验证损失停止改善。We train using 30-way episodes for 1-shot classification and 20-way episodes for 5-shot classification.我们匹配train shot和test shot,并且每个Episode每个类包括15个查询点。我们比较了Ravi和Larochelle报告的baselines,其中包括一个简单的最邻近方法,该方法基于64个训练类上分类网络所学习的特征。另一个基线是匹配网络(普通和FCE)和元学习者LSTM的两个非微调变体。如表2所示,典型网络在这方面达到了最先进的水平。

 

我们进行了进一步的分析,以确定距离度量和每Episode中训练classes的数量对原型网络和匹配网络的影响。为了使这些方法更具有可比性,我们使用我们自己的匹配网络实现,它使用与我们的原型网络相同的嵌入架构。在图2中,我们比较了余弦距离与欧式距离,5-way和20-way  training episodes在1-shot和5-shot场景中,每个Episode每个类中有15个查询点。 我们注意到20-way比5-way获得了更高的准确率,并且推测20-way分类难度的增加有助于网络更好的泛化,因为它迫使模型在嵌入空间中做出更细粒度的决策。此外,使用欧氏距离比预先距离大大提高了性能。这种效果对于原型网络更为明显,在这种网络中,将类原型作为嵌入支持点的平均值进行计算更适合于欧氏距离,因为余弦距离不是Bregman散度。

3.3 CUB Zero-shot Classifification

为了评估我们的方法对zero-shot学习的适用性,我们在Caltech-UCSD Birds (CUB) 200-2011 数据集上进行了实验。CUB数据集包括11788张200种鸟类。在准备数据时严格遵守Reed等人的程序。我们将类划分为100个训练集,50个验证集,50个测试集。对于图像,我们使用通过对原始和水平翻转图像的中间、左上、右上、左下和右下裁剪应用GoogLeNet[28]提取的1024维特征。在测试时我们只使用原始图像的中间部分,对于类元数据,我们使用CUB数据集提供的312维连续属性向量。这些属性编码鸟类的各种特征,如颜色、形状和羽毛图案。我们在1024维图像特征和312维属性向量的基础上学习了一个简单的线性映射来生成1024维输出空间。对于这个数据集,我们发现将类原型(嵌入的属性向量)规范化为单位长度很有帮助,因为属性向量来自与图像不同的域。训练Episode由每个类的50个classes和10个query图片组成。

在固定学习速率为10e-4和weight decay10-5的情况下,通过与Adam的SGD优化embedding。Early stopping on validation loss was used to determine the optimal number of epochs for retraining on the training plus validation set.表3显示,与使用属性作为类元数据的方法相比,我们可以获得更大幅度的最新结果。我们将我们的方法与其他嵌入方法进行比较,例如ALE、SJE和DS-SJE/DA-SJE。我们还比较了最近的聚类方法,该方法在通过微调AlexNet获得的学习特征空间上训练支持向量机。这些zero-shot结果表明,即使数据点(图像)来自与类(属性)相关的不同域,我们的方法也足够通用。

 

4 Related Work(相关工作)

关于度量学习的文献很多,我们在这里总结与我们提出的方法最相关的工作。Neighborhood Components Analysis (NCA)学习Mahalanobis distance 以最大限度提高knn在变换空间种的leave-one-out accuracy 。Salakhutdinov和Hinton利用对神经网络执行转换对NCA进行了扩展。

Large margin nearest neighbor 大边距最近邻(LMNN)分类也试图优化KNN的精度,但使用的hinge loss铰链损失鼓励一个点的局部邻域包含具有相同标签的其他点。DNet KNN[21]是另一种基于边距的方法,它通过使用神经网络来执行嵌入而不是简单的线性变换来改进LMNN。其中,我们的方法与NCA[27]的非线性扩展最为相似,因为我们使用神经网络来执行嵌入,并且我们基于变换空间中的欧氏距离来优化softmax,而不是margin loss。我们的方法和非线性NCA之间的一个关键区别是,我们直接在类上而不是单个点上形成一个softmax,它是根据到每个类的原型表示的距离来计算的。这使得每个类都有一个与数据点数量无关的简明表示,并且避免了存储整个支持集以进行预测的需要。

我们的方法也类似于nearest class mean approach(最近类平均方法),其中每个类都用其示例的平均值表示。这种方法是为了在不需要重新训练的情况下快速地将新类合并到分类器中而开发的,但是它依赖于线性嵌入,并且是为了处理新类附带大量示例的情况而设计的。相反,我们的方法是对非线性嵌入点使用神经网络,并且我们将其与Episode training结合起来去处理few-shot 场景。

Mensink等人尝试扩展它们的方法来执行非线性分类,但是通过允许类具有多个原型实现。他们通过在输入空间上使用k-均值在预处理步骤中找到这些原型,然后对其线性嵌入进行多模态变换。另一方面,原型网络以端到端的方式学习非线性嵌入,而不需要这样的预处理,生成的非线性分类器仍然只需要每个类一个原型。此外,我们的方法自然地推广到其他距离函数,特别是Bregman divergences。

另一种和few-shot学习相关的是Ravi和Larochelle提出的元学习方法。这个的关键是 LSTM dynamics and gradient descent can be written in effectively the same way. 。LSTM可以被训练为自己从给定的Episode中训练一个模型,其性能目标是在查询点上很好地泛化。匹配网络和原型网络也可以看作元学习的形式,因为它们从新的训练片段中动态地生成简单的分类器;然而,他们所依赖的 core embeddings核心嵌入是在训练后固定的。匹配网络的FCE扩展涉及依赖于支持集的二次嵌入。然而,在少数镜头场景中,数据量非常小,一个简单的归纳偏差似乎很有效,无需为每集学习自定义嵌入。

Prototypical networks are also related to the neural statistician from the generative modeling literature, which extends the variational autoencoder to learn generative models of datasets rather than individual points.  neural statistician的一个组成部分是“统计网络”,它将一组数据点归纳为一个统计向量。它通过对数据集中的每个点进行编码,取一个样本均值,

并应用后处理网络获得统计向量上的近似后验。Edwards and Storkey  test their model for one-shot classification on the Omniglot dataset by considering each character to be a separate dataset and making predictions based on the class whose approximate posterior over the statistic vector has minimal KL-divergence from the posterior inferred by the test point. 像neural statistician一样,我们也为每一个类产生一个汇总统计。然而,我们的模型是一个判别模型,适合于我们进行few-shot分类的判别任务。

关于zero-shot学习,在原型网络中使用嵌入元数据类似于之前的方法,因为两者都预测线性分类器的权重。DS-SJE和DA-SJE方法还学习了图像和类元数据的深度多模态嵌入函数。与我们不同,他们学习使用经验风险损失。[3]和[23]都没有使用阶段性训练,这使得我们能够帮助加快训练并使模型正规化。

6、Conclusion

我们提出了一种简单的few-shot学习的方法称作原型网络,其基本思想是,在一个由神经网络学习的表示空间中用样例的平均值来表示每一类。我们通过使用episode训练使得神经网络在few-shot学习中表现的特别好。这种方法比元学习简单并且更有效,即便没有匹配网络进行复杂的拓展也能产生最新的结果(尽管这些方法也可以应用于原型网络)。我们展示了如何通过仔细考虑所选择的距离度量,并通过修改Episode学习过程来大大提高性能。我们进一步展示了如何将原型网络推广到zero-shot setting,并且在CUB-200数据集上实现了最新的结果。未来工作的一个自然方向是利用Bregman发散,而不是平方欧氏距离,对应于超越球面高斯的类条件分布。我们对此进行了初步的探索,包括为一个类学习每个维度的方差。这并没有导致任何经验收益,这表明嵌入网络本身具有足够的灵活性,而不需要每个类的附加拟合参数。总的来说,原型网络的简单性和有效性使其成为一种有前途的few-shot学习方法。

 

 

  • 4
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值