One-shot Learning with Memory-Augmented Neural Networks

摘要

尽管深度学习应用领域最近取得了较大的进展,但是小样本学习的挑战是一直存在的,传统的基于梯度的网络需要大量的数据去学习,通常需要经过大量广泛的迭代训练。当给模型输入新数据时,模型必须低效的重新学习其参数从而充分的融入新的信息,并不会造成较大的干扰影响。具有增强记忆能力的网络结构,例如NTMs具有快速编码新信息的能力,因此能消除传统模型的缺点。这里,我们证明了记忆增强神经网络(memory-augmented neural network)具有快速吸收新数据知识的能力,并且能利用这些吸收了的数据,在少量样本的基础上做出准确的预测。

我们也介绍了一个访问外部记忆存储器的方法,该方法关注于记忆存储器的内容,这和之前提出的使用基于记忆存储器位置的聚焦机制的方法不同。

  1. 介绍

当前深度学习的成功取决于基于梯度的优化算法应用于高容量模型(神经元数量)的能力。这种方法在许多以原始感官为输入的大型监督任务上已经取得了非常好的结果,例如图像分类、语音识别、游戏等。值得注意的是,这些任务上的表现通常是在大型数据集上经过广泛的增量式训练来评估得出的。相反,许多兴趣问题(many problems of interest)需要从少量的数据中快速推断出结果。在one-shot learning的记心中,单一的观察结果会导致行为的突然转变。

这种灵活的适应是人类学习中一个重要的方面,从发动机的控制到抽象概念的获取都得到了表现。根据少量信息的推断生成新的行为,比如推断一个只在上下文中出现过一两次的单词的全局适用性,这是超出当代智能能力的。这对深度学习提出了严峻的挑战,只有在少数样本逐一呈现的情况下才有一种简单的基于梯度的解决方案:从目前可用数据中完全重新学习参数。但是这种方法往往会导致不良学习和灾难性干扰,这时非参数方法往往被认为更合适。

可是先前的工作提出一种从稀疏数据中学习的策略,并取决于元学习的概念。虽然meta-learning。虽然meta-learning术语已经被用在很多领域。元学习通常考虑为学习两种水平的场景,并且每个水平和不同的时间尺度有关。快速学习通常出现在一个任务内,例如在特定的数据集中进行准确分类。这种学习是由在任务中逐渐积累的知识来指导的,这些知识捕获了任务结构在目标域中的变化方式或变化规律。考虑到这种结构的两层形式,因此也被叫做learning to learn。

已经提出的具有记忆能力的神经网络能够证明确实能够进行元学习。这些网络能够通过权重更新改变偏置的值,并能通过快速学习记忆存储中的缓存表示(cache representations in memory stores)来调整输出结果。例如用lstms当做元学习的网络能根据少量的数据样本就能快速学习到之前没有见过的二次函数。

具有记忆能力的神经网络给元学习在深度网络中提供了一种可行的方法。但是使用非结构化循环网络结构的内在记忆单元这种特定的策略不可能扩展到每个新任务都需要快速编码吸收大量新信息的场景。一个可扩展的解决方案必须有以下必要的要求:首先,信息必须稳定的表现形式存储在记忆存储器中 (以便在需要时可以可靠地访问),并且记忆中的元素可寻址(以便可以选择性的访问信息);其次,参数的数量不应该和记忆存储器的大小有关联。标准的具有记忆的结构例如LSTMs并没有这两种特性。然而最近的架构中如神经图灵机NTMS和记忆网络满足了这两个特点的要求。因此在文中我们从一个高容量的记忆增强神经网络的角度重新考虑了元学习的问题和设置(setting),(注:这里MANN指配备外部记忆的网络,而不是其他内部记忆单元的架构如LSTM)。

我们的方法结合最有利的两部分:通过梯度下降慢慢的从原数据中获取有用表示(representations)的抽象方法;通过外部记忆存储模块在一次表示之后(after a single presentation)快速吸收没有见到过的知识。这种结合使元学习更加健壮,并扩展了可以有效应用深度学习的问题范围。

  1. 元学习任务方法论

通常我们选择一个参数θ在数据集D上去最小化损失函数L。可是对于元学习来说,我们选择参数来降低数据集分布P(D)中的期望损失。

要做到这一点,正确的任务设置至关重要。在我们的任务设置中,一个任务或者插曲片段(a task, or episode)涉及一些数据集D的表示。Yt既是一个目标,也是以时间偏移的方式与xt一块作为输入。这个网络的目的是在给定的时间戳t上为xt输出正确的标签。重要的是,标签是从数据集中混洗得到的,这样能够防止网络缓慢的学习样本和类的绑定关系来更新权重。相反的的是,网络必须将数据样本存在内存中,直到下一个时间戳到达,正确的类标签被展示出来,在这之后,样本和类标签的对应关系能被发现并且存储这种关系信息供以后使用。因此,对于给定的一段情节(episode),理想的表现会涉及到对第一个类的标签值(the first presentation of a class,我理解为类的值)的随机猜测,(因为标签被混洗了,不能根据之前的情节推断出正确的标签),并且之后使用记忆存储器来实现准确率的完美预测。最终,这个系统目标是对预测分布p进行建模,在每一个时间步引起相应的损失。

这个任务结构包含可利用的元知识:元学习的模型学习将数据表示绑定到其对应的正确标签,而不管数据表示或标签的实际内容如何,并且将采用一般方案将这些绑定表示(bound representations)映射到正确的类或用于预测的函数值。

  1. 记忆增强模型

3.1神经图灵机

神经图灵机是MANN一种完全不同的实现。他包括一个控制器,例如一个前馈网络或者LSTM,这和一个使用一些读写头的额外记忆模块相互影响。图灵机中的记忆模块的记忆单元编码和索引都是很快的,向量表示可能在每个时间步骤被放入或取出内存。这种能力使NTM称为元学习和短时预测完美的候选者,因为它既能通过慢的权重更新实现长期存储,并且通过额外记忆模块实现短期存储。如果NTM能够学习一种通用策略来将各种表示(representations,这里指内存单元中记录的信息)类型放入记忆单元中,并且能够学习之后如何使用这些表示来做预测,那么他可能利用他的速度来对仅见过一次的数据进行准确预测。

       我们模型中的控制器要么使用LSTMs或者前馈网络。控制器使用读写头与外部存储器模块交互,读写头分别用于从存储器中检索表示(representations)或将它们放入存储器中。给定一些输入xi,控制器生成一个键值kt,这个键值被存入记忆矩阵Mt的一行,或者被用于从一行中索引一个特定的记忆单元i,Mt(i),当索引一个记忆单元Mt的时候,会使用余弦相似度。

用于去产生读权重向量Wrt,根据以下公式计算得到

一个记忆单元rt,通过使用权重向量进行索引:

这个记忆单元的内容被控制器作为一个分类器的输入泪如softmax层的输入,或者作为下一个控制器状态的额外输入。

3.2、最少或最近使用的记忆信息

    在之前NTM的例子中,记忆信息通过内容或者位置被索引。基于位置的索引常常被用于迭代更新的步骤,就像沿着磁带跑一样,也回用于在记忆信息上的长距离跳跃。这种方法对于基于序列预测的任务是有优势的,可是这种方式对于强调独立于序列之外的信息的任务并不是最优的。因此,在我们的模型中,包含一个新设计的读取记忆信息的模式叫做LRUA。

    LRUA模型是一个纯粹的基于内容的记忆读写方式,记忆信息要么被写到斤少使用的记忆模块的位置或者最近使用的记忆模块的位置。这个模块看重有关信息的准确编码(吸收提取数据的知识),并且是完全的基于内容的索引。新的信息被写入到很少使用的位置或者写入到最后使用的位置,保存最近编码的信息,这可以用更加新的、可能更相关的信息更新的记忆信息。这两种方式的不同在于先前的读参数和使用参数(usage weights wtu),这些使用参数通过衰减参数逐步更新参数值,

这里,gama是衰减参数,读向量参数由前边计算出来,最少使用的权重能通过用户参数计算出来,其中m(v,n)表示前n个

n是读记忆的数目,写参数向量(write weights wtw)由以下方式计算得到:

σ(·) 是sigmoid函数,

记忆信息能够被写到标记为零记忆槽,或者之前被使用过的槽(slot),如果是之前使用过的槽,那么就是最少被使用的槽,并且原本槽里的记忆信息会被删除。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值