论文《Matching Networks for One Shot Learning》阅读

Matching Networks for One Shot Learning

摘要

In this work, we employ ideas from metric learning based on deep neural features and from recent advances that augment neural networks with external memories. Our framework learns a network that maps a small labelled support set and an unlabelled example to its label, obviating the need for fifine-tuning to adapt to new class types. We then defifine

one-shot learning problems on vision (using Omniglot, ImageNet) and language tasks.

1、Introduction

This motivates the setting we are interested in: “one-shot” learning, which consists of learning a class from a single labelled example.

深度学习虽然在很多方面取得来很大的进步,但是缺点之一是需要大量的数据集。数据加强和正则化虽然可以减轻在只有小数据样本上的过拟合,但是不能从根本上解决这个问题。学习过程非常缓慢并且依赖于大的数据集,需要使用SGD进行很多次的权重更新。这很大程度上是由于模型参数化的原因,模型需要非常缓慢的学习到它的模型。

相反,许多非参数的模型可以快速接受新例子,同时不会产生灾难性的遗忘。 一些非参数的模型不需要训练但是会依赖选择的度量。我们的目标是将参数化和非参数化模型中的最佳特性结合起来-即快速获取新的示例,同时从常见的示例中提供优秀的概括。

我们模型的新颖之处在于两个方面,一方面,在模型层面和训练过程,我们提出了匹配网络,一种利用注意力和记忆的最新进展来实现快速学习的神经网络。另一方面,我们的训练程序是基于一个简单的机器学习原则:测试和训练条件必须匹配。因此,为了训练我们的网络进行快速学习,我们训练它,每个类只显示几个示例,将任务从minibatch切换到minibatch。

在下文将介绍模型,一般设置和实验。

2 Model

非参数的方法去解决one-shot问题基于两个方面,首先是模型结构跟随了神经网络记忆增强方面的最新进展,给一个支持集,模型定义一个函数,对每一个支持集都做一个映射。其次,采用一种训练策略,which is tailored for one-shot learning from the support set S.

2.1 Model Architecture

We draw inspiration from models such as sequence to sequence (seq2seq) with attention [2], memory networks [29] and pointer networks [27].

在这些模型中,neural attention mechanism都是完全可微的,被定义去读记忆矩阵,记忆矩阵中存储了解决任务的有用信息。

set-to-set framework

关键点在于在训练的时候,不需要更改网络,匹配网络更够生成为没见过的类生成合理的测试标签。更准确的说,对于给定的测试样例x,得到一个输出y的概率分布。定义一个映射 P是神经网络参数化的。然后,对于给的一个新的样例支持集,我们使用用P定义的参数神经网络去预测测试样例x的标签y。我们的预测输出实际上是

 

 

上述式子本质上描述了一个新类别的输出是支持集中标签的线性组合。其中注意力机制a是一个x×x的内核,上述式子类似一个内核密度估计器。依据一个距离度量和适当的常数,如果的距离超过b,注意力机制就为0,则上述式子类似‘k-b'最邻近。然而方程1即包括KDE也包括KNN。我们可以把上述理解为一种特殊的联想记忆,给一个输入,我们指向支持集中相应的示例,找到它的标签。然而,与其他注意力记忆机制[2]不同,(1)在本质上是非参数的:随着支持集大小的增加,所使用的记忆也是如此。因此,分类器CS(ˆx)定义的函数形式非常灵活,可以很容易地适应任何新的支持集。

2.1.1 The Attention Kernel

方程一依赖于注意力机制a的选择,它完全指定了分类。最简单的形式是在余弦距离上使用softmax,,并且带有嵌入函数f,g去嵌入x。

虽然与度量学习有关系,但是我们发现用式子1定义的分类器是对于给定的支持集和分类样本,让对充分对齐是足够的。这种损失还与Neighborhood Component Analysis (NCA) [18], triplet loss [9] or large margin nearest neighbor [28]有关系。

损失是简单的可微的,所以我们可以找到一个端到端的参数优化。

2.1.2 Full Context Embeddings

模型的主要新颖至于在于重新解释一个学习的很好的框架(带有外部记忆的神经网络)去做one-shot学习。与度量学习密切相关的是,嵌入函数f和g通过对空间特征X的提升实现最大化在方程一中提到的分类函数的准确率。

尽管分类策略完全依赖于通过设置的整个支持集,我们可以使用余弦相似性去“attend”,“point”或者是简单的进行最邻近计算都是myopic,因为每一个元素x都是通过g(x)独立嵌入的。此外,通过函数f,S能够修改我们是如何嵌入test x的。

我们建议通过一个函数嵌入集合中的元素,该函数除xi外,还包含完整集S,即g变成g(xi,S)。作为整个集合S的函数,g可以修改如何嵌入。当元素xi和xj非常相似时这是非常重要的,在这种情况下,更改嵌入x的函数是有益的。我们使用双向长-短期内存(LSTM)[8]在支持集S的上下文中编码xi,这被认为是一个序列(更精确的定义见附录)。

第二个问题可以通过LSTM来固定,在整个集合S上进行read-attention,输入等价于x:

是从网络中产生的特征,输入到LSTM中,K是LSTM展开步骤(unrolling steps)的固定数目,g(s)是我们得到的集合使用g嵌入。这允许模型去忽视在支持集S中的一些元素,但是把深度加入到attention的计算中。

2.2 Training Strategy

我们定义一个任务T作为所有可能标签集合L上的分布。通常,我们考虑T将所有数据集都统一到几个唯一的类(例如,5),每个类的示例(例如,最多5个)。在这种情况下,从任务T中取得标签集合L,L~T,有5到25个样例。

为了形成一个“Episode”去计算梯度和更新我们的模型,我们首先从T中取样L,L是标签集合。然后我们用L去取样支持集S和batch B(S和B都是被标记的样例)。然后匹配网络被训练去最小化以支持集S为条件的批次B中标签的预测误差。这是一种元学习的形式,因为训练过程明确地学习从给定的支持集学习,以尽量减少一个批的损失。匹配网络训练过程中的目标为:

用方程2训练θ 产生一个模型,当从新的标签不同分布中抽样时,模型会工作的很好。关键的是,我们的模型不需要对它从未见过的类进行任何微调,因为它的非参数性质。Obviously, as T diverges far from the T from which we sampled to learn θ, the model will not work

3 Related Work

3.1 Memory Augmented Neural Networks

fifixed vectorsmore expressive models

3.2 Metric Learning

Many links between content based attention, kernel based nearest neighbor and metric learningThe most relevant work is Neighborhood Component Analysis (NCA)和非线性的版本。在one-shot学习中使用整个支持集更合适。

4Experiments

我们的所有实验都围绕着相同的基本任务:an N-way k-shot学习任务。每种方法都提供了一组K个标记的例子,这些例子来自每一个以前没有接受过训练的N个类。任务是将这些无关的无标签的样例分类到这N个类中。我们将多个备选模型(作为基线)与匹配网络进行了比较。

4.1 Image Classifification Results

For vision problems, we considered four kinds of baselines: matching on raw pixels, matching on discriminative features from a state-of-the-art classififier (Baseline Classififier), MANN [21], and our reimplementation of the Convolutional Siamese Net.

基准分类器被训练去分类一张图片到一个在训练集中原始出现的类,但是不包括之前说的N类。我们使用这个网络并且使用最后一层的特性(在softmax之前)进行最邻近匹配,这在很多任务中都取得了较好的结果。在【11】之后,将卷积连网训练成原来训练数据集中相同或不同的任务,然后使用最后一层进行最近邻匹配。

我们还尝试仅使用从L采样的支持集S来进一步微调特征。这产生了大量的过度拟合,但是考虑到我们的网络是高度正则化的,可以产生额外的增益。

 

4.1.1 Omniglot

Omniglot包括1623个字符,它们来自50个不同的字母表。每一个都是由20个不同的人绘制的。这有很多个类别但是每个类只有很少的样例。

The N-way Omniglot task:选择N个没有见过的字符类,独立于字母表作为L。为模型提供一个类别一张图片,作为S~L,B~L。我们通过旋转90的倍数来增强数据,使用1200个字符作为训练剩下的用作评估。

我们使用CNN作为嵌入函数,由一堆模块组成,每一个都是3×3的卷积带有64个滤波器,然后是批量归一化,一个非线性的ReLu,2×2的最大池化层。我们将图片调整为28×28,这样我们使用4个模块就可以1×1×64的结果特征映射,从而产生我们的嵌入函数f(x)。一个全连接层后跟着softmax用来定义基准分类器。

  1. shot,5-shot,5-way,20-way,我们的模型都比基准表现好。对于k-shot分类器使用更多的样例是有帮助的;5-way分类比20-way分类更简单。

4.1.2 ImageNet

我们实验的设置和Omniglot相同,但是我们考虑一个rand和dogs设置。在rand设置中,我们随机从训练集中移除了118个标签,然后训练只在这118个类中,我们表示为。对于dogs设置,我们从狗的后代中删除了所有类别(共118个)然后在没有狗的类别上进行训练,然后在狗的类别上进行测试,。我们设计了一个新的数据集minImageNet,包括60000张大小为84×84的彩色图片,有100个类,每一个有600个样例。我们使用80个类训练,然后再剩余的20个类上进行测试,所以我们现在有 randImageNet,dogsImageNet,miniImageNet。

和Omniglot一样,匹配网络比基准网络表现的好。但是miniImageNet比Omniglot任务要难,它让我们去评估Full Contextual Embeddings的灵活度。不管有没有微调,FCE提高了匹配网络的表现。

 

我们在全尺寸的InamgeNet上做实验。我们的基准分类器是Inception。We also compared to features from an Inception Oracle classififier trained on all classes in ImageNet, as an upper bound. 我们用从Inception 分类器上得到的参数初始化匹配网络的特征提取器f和g,而不是从零开始在这些大的任务上训练匹配网络,然后我们进一步在5-way 1-shot任务上训练数据集,结合Full Context Embeddings和我们的匹配网络和训练策略。

randImageNet和dogsImageNet的结果展示在表3中。Inception Oracle的表现接近完美。

当仅在上训练时,匹配网络比Inception提高了进6%,当在上训练时,将错误率减半。从所有的错误来看,“盗梦空间”有时似乎更喜欢图像,而不是其他图像(这些图像往往像第二列中的示例一样混乱,或者颜色更恒定)。另一方面,匹配网设法从支持集S’中出现的这些异常值中恢复。

如果我们将我们的训练策略调整为来自细粒度集的样本S而不是从Image Net类树的叶子上统一采样F标签,可以实现改进。我们把这作为今后的工作。

4.1.3 One-Shot Language Modeling

任务如下:给一个缺少词的查询句,和一组支持句,每个句子都有一个缺失的单词和一个对应的标签,从支持集中选择最匹配查询句的标签。

 

句子来自the Penn Treebank dataset。在每一次试验中,我们确保集合和批处理都填充了不重叠的句子。这意味着我们不使用频率很低的单词。和图片任务一样, each trial consisted of a 5 way choice between the classes available in the set。在整个句子匹配任务中,我们使用了20的批处理大小,并且在k=1,2,3之间改变了设置大小。我们确保每一组都有相同数量的句子可供使用。我们将单词分成随机抽样的9000个用于训练,1000个用于测试,我们使用标准测试集来报告结果。因此,无论是单词还是句子在测试期间都没有在训练期间见过。

我们将我们的one-shot匹配模型与oracle LSTM进行比较。在设置中,LSTM具有一个不公平的优势,因为它不是做one-shot学习而是看到所有的数据,所有这应该被当作一个上限。我们检验一个相似的设置,其中给模型一个呆着一个空的句子并且还有五个可能的单词,其中包括正确答案。对于这五个词,模型给出了一个对数似然,并选择其中数值最大的。

LSTM语言模型oracle在测试集上达到了72.8的准确率。. Matching Networks

with a simple encoding model achieve 32.4%, 36.1%, 38.2% accuracy on the task with k = 1, 2, 3 examples in the set, respectively.

Two related tasks are the CNN QA test of entity prediction from news articles [5], and the Children’s Book Test (CBT) 

5 Conclusion

我们在这篇论文中介绍了匹配网络,一种新的神经网络结构,通过相应的训练制度,能够对各种one-shot分类任务执行最新的性能。这里有几个关键的点,首先,如果你训练网络进行one-shot学习,那么one-shot学习就变得容易。其次,神经网络中的非参数结构使得网络在相同的任务中更容易记忆和适应新的训练集。将这些观测结果结合起来,产生匹配网络。我们模型的一个明显的缺点是,随着支持集的大小增加,每次梯度更新的计算变得更加昂贵。尽管有稀疏的和基于抽样的方法来缓解这一问题,但我们未来的许多努力将集中在这一限制上。此外,如ImageNet dogs子任务中所示,当标签分布具有明显的偏差(例如细粒度)时,我们的模型会受到影响。

 

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值