小样本学习&元学习经典论文整理||持续更新
核心思想
本文提出一种终生记忆模块(life-long memory module)能够使得许多神经网络实现单样本学习,核心思想还是在训练时将每个类别样本的特征信息与对应的标签值保存下来,测试时利用最近邻思想,选择与查询样例最接近的K个样本,并据此预测插叙样例的标签。记忆模块由“键-值”对构成,“键”
K
K
K是神经网络特定层的输出,“值”
V
V
V则是给定样本对应的标签,此外还有一个额外的向量
A
A
A用于保存各个“键-值”对的“年龄”,记忆模块
M
M
M如下式
给定一个查询向量
q
q
q,且
q
q
q是经过归一化处理的,则
M
M
M中
q
q
q的最近邻被定义为其“键”与
q
q
q的内积最大的那个“键-值”对,如下式所示
因为“键”和查询向量都是经过归一化处理的,因此上式可以等价为计算余弦相似性。进一步拓展计算
k
k
k个最近邻,并按照由近到远的顺序排列
得到最主要的结果
V
[
n
1
]
V[n_1]
V[n1],并计算余弦相似性
d
i
=
q
⋅
K
[
n
i
]
d_i=q\cdot K[n_i]
di=q⋅K[ni],进一步得到
s
o
f
t
m
a
x
(
d
i
⋅
t
,
.
.
.
,
d
k
⋅
t
)
softmax(d_i\cdot t,...,d_k\cdot t)
softmax(di⋅t,...,dk⋅t),其中
t
t
t表示softmax温度的倒数,本文取
t
=
40
t=40
t=40。文设计的记忆模块是能够训练的,那么记忆模块是如何进行训练和更新的呢?
每当输入一个新的查询向量
q
q
q,假设其真实标签为
v
v
v,经计算得到的最近邻为
n
1
n_1
n1。如果
n
1
n_1
n1对应的值
V
[
n
1
]
=
v
V[n_1]=v
V[n1]=v,则是需要将
n
1
n_1
n1对应的键进行更新,如下式
并且将对应的年龄向量
A
[
n
1
]
A[n_1]
A[n1]更新为0。如果
V
[
n
1
]
≠
v
V[n_1]\neq v
V[n1]=v,那么就需要将新的键值对
(
q
,
v
)
(q,v)
(q,v)写入记忆模块,写到哪里呢?这需要从年龄最大的项里面随机选择一个
n
′
n'
n′(年龄越大表示越长时间没有被更新过了),然后更新对应的值
最后把所有没被更新的“键-值”对其年龄都加1。
使用时需要考虑如何高效地计算最近邻,假设一个小批次的查询向量构成矩阵
Q
=
[
q
1
,
.
.
.
,
q
b
]
Q=[q_1,...,q_b]
Q=[q1,...,qb],只需要计算一个矩阵乘法
Q
×
K
T
Q\times K^T
Q×KT就能得到对应的距离矩阵。如果精确计算模式还是太慢,可以使用局部敏感哈希(LSH)近似计算最近邻。首先随机选择一些经过规范化处理的哈希向量
h
1
,
.
.
.
h
l
h_1,...h_l
h1,...hl,则查询向量
q
q
q对应的哈希编码为一串二进制数字
b
1
,
.
.
.
,
b
l
b_1,...,b_l
b1,...,bl,其中
b
i
=
1
b_i=1
bi=1当且仅当
q
⋅
h
i
>
0
q\cdot h_i>0
q⋅hi>0。这样我们可以得到所有“键”和查询向量对应的哈希编码,如果两个向量对应的哈希编码中相同的位越多,则表示二者相似的可能性越大,因此在计算最近邻时,只需要对哈希编码相同的向量进行计算。
对于采用卷积神经网络实现的小样本分类任务,运用本文提出的记忆模块的方式非常简单,将最后一层卷积层输出的向量作为查询向量,计算与记忆模块中的最近邻作为预测的结果。
实现过程
网络结构
本文设计的记忆模块本身不具备网络结构,可以配合各种任务网络使用。
损失函数
本文提出一种记忆损失用于提高特征向量的表征能力,对于查询向量
q
q
q和对应的标签
v
v
v,首先计算
k
k
k个最近邻中,类别相同的“键-值”对的最小索引值
p
,
V
[
n
p
]
=
v
p, V[n_p]=v
p,V[np]=v,和类别不同“键-值”对中的最小索引值
b
,
V
[
n
b
]
≠
v
b, V[n_b]\neq v
b,V[nb]=v,则记忆损失为
式中
α
\alpha
α表示阈值参数,本文取
α
=
0.1
\alpha=0.1
α=0.1。因为对于相同的两项余弦相似性最大,所以记忆损失函数的目的就是,使正确的“键”的余弦相似度最大化,使错误的“键”的余弦相似度最小化,且当两者之间的差距超过一定的阈值时,就不再传递损失了。
算法推广
本文除了可以与CNN结合应用于图像分类任务外,还可以与LSTM等结构结合,应用于机器翻译任务,此处只介绍与Google Neural Machine Translation (GNMT) 模型结合的方式。
如图所示GMMT包含:编码器,注意力模块和解码器三个部分,在结合记忆模块时,保留编码器部分不动,将注意力模块输出的向量作为查询向量保存在记忆模块中。在GMMT模型中,注意力模块的输出会用于解码器的每个LSTM模块中(除了第二个),因此记忆模块中向量的取用也会并行的应用于每个LSTM模块。在最后的softmax层之前,将记忆模块的输出和最后一个LSTM模块的输出利用一个线性层结合起来。
创新点
- 提出一种记忆模块,可以结合CNN,RNN等多种模型,通过一种终生记忆的方式实现单样本学习任务
- 提出一种记忆损失函数和更新方法,使得记忆模块可以作为一种额外的监督,提高特征提取网络的表征能力
- 提出一种高效的最近邻近似计算方法,提高了计算效率
算法评价
本文是采用外部记忆模块实现小样本学习的策略,其记忆模块的损失计算和更新策略(特别是年龄向量的引入)还是很有新意的。而且该模块的适用性很强,能够与不同任务,不同类型的神经网络相结合。但该模块也有许多能够进一步改进的地方,包括距离度量方式,最近邻选择的方法,记忆模块的更新方式等。
如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。