小样本学习&元学习经典论文整理||持续更新
核心思想
本文提出一种带有记忆增强神经网络(Memory-Augmented Neural Networks,MANN)的元学习算法用于解决小样本学习问题。我们知道LSTM能够通过遗忘门有选择的保留部分先前样本的信息(长期记忆),也可以通过输入门获得当前样本的信息(短期记忆),这一记忆的方式是利用权重的更新隐式实现的。而在本文中,作者希望利用外部的内存空间显式地记录一些信息,使其结合神经网络自身具备的长期记忆能力共同实现小样本学习任务。
如图(a)所示,整个训练过程分成多个Episode,每个Episode中包含若干个样本
x
x
x和对应的标签
y
y
y,将所有的样本组合成一个序列,
x
t
x_t
xt表示在
t
t
t时刻输入的样本,
y
t
y_t
yt表示与之对应的标记,但要注意的是输入时
x
t
x_t
xt和
y
t
y_t
yt并不是一一对应的,而是错位对应,即
(
x
t
,
y
t
−
1
)
(x_t,y_{t-1})
(xt,yt−1)为
t
t
t时刻的输入,这样做的目的是让网络有目的的记住先前输入的信息,因为只有保留有效的信息在下次再遇到同类样本才能计算得到对应的损失。在每个Episode之间样本序列都是被打乱的,这是为了避免网络在训练过程中慢慢记住了每个样本对应的位置,这不是我们希望的。可以看到对于每个单元(图中灰色的矩形块)他的输入信息既有前一个单元输出的信息,又有当前输入的信息,而输出一方面要预测当前输入样本的类别,又要将信息传递给下个时刻的单元,这与LSTM或RNN很相似。在此基础上作者增加了一个外部记忆模块(如图(b)中的蓝色方框),他用来储存在当前Eposide中所有"看过"的样本的特征信息。怎样去理解他呢?比如网络第一次看到一张狗的照片,他并不能识别出它是什么,但是他把一些关键的特征信息记录下来了,而且在下个时刻网络得知了它的类别标签是狗,此时网络将特征信息与对应的标签紧紧地联系起来(Bind),当网络下次看到狗的照片时,他用此时的特征信息与记忆模块中储存的特征信息进行匹配(Retrieve,真实的实现过程并不是匹配,而是通过回归的方式获取信息,此处只是方便大家理解),这样就很容易知道这是一只狗了。这一过程其实与人类的学习模式非常接近了,但作者是如何利用神经网络实现这一过程的呢?作者引入了神经图灵机(NTM),为了方便下面的讲解此处需要先介绍一下NTM。
神经图灵机的结构如上图所示,它由控制器(Controller)和记忆模块(Memory)构成,控制器利用写头(Write Heads)向记忆模块中写入信息,利用读头(Read Heads)从记忆模块中读取信息。回到本文的模型中,作者用一个LSTM或前向神经网络作为控制器,用一个矩阵
M
t
M_t
Mt作为记忆模块。给定一个输入
x
t
x_t
xt,控制器输出一个对应的键(Key)
k
t
k_t
kt,可以理解为是一个特征向量,这个特征向量一方面要通过写头写入记忆模块,一方面又要通过读头匹配记忆模块中的信息,用于完成分类任务或回归任务。我们先介绍读的过程,假设记忆模块中已经储存了许多的特征信息了,每个特征信息就是矩阵中的一行(特别注意,此处一行不是代表一个特征向量,而是某种抽象的特征。写入的过程并不是将特征向量一行一行地堆放到记忆模块中,写入的过程远比这个复杂),此时我们要计算当前特征向量
k
t
k_t
kt与记忆模块
M
t
M_t
Mt中的各个向量之间的余弦距离
D
(
k
t
,
M
t
(
i
)
)
D(k_t,M_t(i))
D(kt,Mt(i))(原文中用
K
K
K表示,为了避免与
k
t
k_t
kt混淆,特此改为
D
D
D),然后利用softmax函数将其转化为读取权重
w
t
r
(
i
)
w^r_t(i)
wtr(i),最后利用回归的方式(加权求和)计算得到提取出来的记忆
r
t
=
∑
i
w
t
r
(
i
)
M
t
(
i
)
r_t=\sum_iw^r_t(i)M_t(i)
rt=∑iwtr(i)Mt(i)。控制器一方面将
r
t
r_t
rt输入到分类器(如softmax输出层)中获取当前样本的类别,另一方面将其作为下一时刻控制器的一个输入。
写的过程就是描述如何合理有效的将当前提取的特征信息存储到记忆模块中。作者采用了最少最近使用方法(Least Recently Used Access,LRUA),具体而言就是倾向于将特征信息存储到使用次数较少的记忆矩阵位置,为了保护最近写入的信息;或者写入最近刚刚读取过的记忆矩阵位置,因为相邻两个样本之间可能存在一些相关信息。写入的方法也是为记忆模块中的每一行计算一个写入权重
w
t
w
(
i
)
w^w_t(i)
wtw(i),然后将特征向量
k
t
k_t
kt乘以对应权重,在加上先前该位置保存的信息
M
t
−
1
(
i
)
M_{t-1}(i)
Mt−1(i)得到当前时刻的记忆矩阵
M
t
(
i
)
=
M
t
−
1
(
i
)
+
w
t
w
(
i
)
k
t
M_t(i)=M_{t-1}(i)+w^w_t(i)k_t
Mt(i)=Mt−1(i)+wtw(i)kt。而写入权重
w
t
w
w^w_t
wtw计算过程如下
其中
w
t
−
1
r
w^r_{t-1}
wt−1r表示上一时刻读取权重,该值由读的过程计算得到,权重越大表示上一时刻刚刚读取过这一位置储存的信息;
σ
(
)
\sigma()
σ()表示sigmoid函数,
α
\alpha
α表示一个门参数,用于控制两个权重的比例。
w
t
−
1
l
u
w^{lu}_{t-1}
wt−1lu表示上一时刻最少使用权重,其计算过程如下
其中,
m
(
w
t
u
,
n
)
m(w_t^u,n)
m(wtu,n)表示向量
w
t
u
w_t^u
wtu中第
n
n
n个最小的值,
n
n
n表示内存读取次数,
w
t
u
w_t^u
wtu表示使用权重,其计算过程如下
包含三个部分,上个时刻的使用权重
w
t
−
1
u
w_{t-1}^u
wt−1u,
γ
\gamma
γ是衰减系数,读取权重
w
t
r
w^r_t
wtr和写入权重
w
t
w
w^w_t
wtw,当
w
t
u
(
i
)
w_t^u(i)
wtu(i)小于
m
(
w
t
u
,
n
)
m(w_t^u,n)
m(wtu,n)时表示位置
i
i
i是使用次数最少的位置之一,那么在下次写入时,使用该位置的概率就更高。
作者称该模型是一种元学习算法,那是如何体现元学习过程的呢?我的理解是控制机本身是任务学习器(Learner),用于提取特征信息并预测分类,而整个模型则是一个元学习器(Meta-learner)用于学习如何将信息写入/读出记忆模块。
实现过程
网络结构&损失函数&训练策略
这部分内容论文中没有特别具体的介绍,本身也不重要,核心在于整个模型的思想,具体的结构和损失函数可以结合任务需求自行选定。
网络推广
该模型可以应用于分类和回归任务。
创新点
- 设计了一种带有记忆增强神经网络的元学习算法,结合长期记忆和短期记忆两方面优势,能够在看过某种类型的图片一眼(one-shot),就能在下次遇到同类图片时很快识别出来
- 利用神经图灵机模型实现了记忆增强网络,写入的过程将特征信息与对应标签紧密关联起来,读取的过程又将特征向量准确分类
算法评价
该算法巧妙的将NTM应用于小样本学习任务中,采用显示的外部记忆模块保留样本特征信息,并利用元学习算法优化NTM的读取和写入过程,最终实现有效的小样本分类和回归。文中提到的长期记忆是通过控制器网络权重参数的更新实现的,因为采用了错位配对的方式,因此要到第二次见到该类别的图像时才能得到相应的损失,并进行反向传递,因此权重更新过程是非常缓慢的,能够保留很久之前的信息(如果权重更新速度很快,可能为了识别新的图片类别,就迅速忘记之前识别过的图片了)。短期记忆是由外部记忆模块实现的,有人可能会觉得这个记忆模块不是随着训练过程不断储存各个时刻的信息吗?为什么叫做短期记忆呢?这是因为作者在两个Eposide之间会清除记忆模块,以避免两个Eposide记忆之间相互干扰,而一个Eposide只是有若干个类别的少量样本构成的,相对于整个学习过程他仍然属于短期记忆。该算法整个思想都非常的新颖,NTM模型也十分的巧妙,作者自己也认为非常接近人类学习认知的模式了,但不知道是不是因为训练比较困难的原因,该方法并没有大规模的推广。在学习该文章时,有必要提前了解一下NTM模型的原理,否则学习起来会比较困难。
如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。