小样本学习&元学习经典论文整理||持续更新
核心思想
本文提出一种采用元学习的方式解决小样本学习任务的方法。作者首先指出为什么普通的深度学习算法在小样本学习任务中表现不佳,他认为有两个方面的原因:第一,基于梯度的优化算法(如momentum,ADAM)不是为小样本学习任务设计的,其需要大量的样本经过数百万次迭代才能收敛到一个较好的结果;第二,对于每个单独的数据集,网络都需要从一个随机的初始化状态开始更新参数,这导致他无法在只经过几次迭代的情况下就很好的收敛。针对上述问题,作者提出使用元学习的方式来解决小样本学习任务,即将分类任务中原本需要人为设定的参数(如学习率,衰减率,初始权重等)看做需要学习的参数,通过学习的方式优化这些参数。因此整个任务可以分成两个级别的学习器,一个是元学习器(meta-learner)用于学习上面提到的优化参数;另一个是任务学习器(learner)用于学习分类器中的权重参数。作者认为元学习能够同时捕捉每个单独任务中的短期知识和整个任务中共同具备的长期知识,因此能够引导分类器快速的收敛到一个较好的结果。
整个元学习的数据集可以分成三个部分:元训练集
S
m
e
t
a
−
t
r
a
i
n
S_{meta-train}
Smeta−train,元测试集
S
m
e
t
a
−
t
e
s
t
S_{meta-test}
Smeta−test和元验证集
S
m
e
t
a
−
v
a
l
i
d
a
t
i
o
n
S_{meta-validation}
Smeta−validation。其中,元训练集
S
m
e
t
a
−
t
r
a
i
n
S_{meta-train}
Smeta−train又分为训练集
D
t
r
a
i
n
D_{train}
Dtrain和测试集
D
t
e
s
t
D_{test}
Dtest两个部分,训练集
D
t
r
a
i
n
D_{train}
Dtrain用于训练分类器网络,目标是提高测试集
D
t
e
s
t
D_{test}
Dtest上的分类准确率,而测试集
D
t
e
s
t
D_{test}
Dtest得到的损失用于训练元学习器,更新网络的优化参数。元测试集
S
m
e
t
a
−
t
e
s
t
S_{meta-test}
Smeta−test用于评估网络的学习效果,元验证集
S
m
e
t
a
−
v
a
l
i
d
a
t
i
o
n
S_{meta-validation}
Smeta−validation用于调整元学习器的超参数。
对于普通的分类学习任务而言,其参数更新策略如下
θ
t
=
θ
t
−
1
−
α
t
▽
θ
t
−
1
£
t
\theta _{t}=\theta _{t-1}-\alpha _{t}\triangledown _{\theta _{t-1}}\text{\pounds}_{t}
θt=θt−1−αt▽θt−1£t其中
θ
t
\theta _{t}
θt表示第
t
t
t次迭代的权重参数,
θ
t
\theta _{t}
θt表示学习率,
▽
θ
t
−
1
£
t
\triangledown _{\theta _{t-1}}\text{\pounds}_{t}
▽θt−1£t表示损失函数
£
t
\text{\pounds}_{t}
£t对于权重参数
θ
t
−
1
\theta _{t-1}
θt−1的梯度。而这一过程与LSTM的学习过程非常相似,如下式所示
c
t
=
f
t
⊙
c
t
−
1
+
i
t
⊙
c
~
t
c_t=f_t\odot c_{t-1}+i_t\odot \tilde{c}_t
ct=ft⊙ct−1+it⊙c~t只需要令遗忘门
f
t
=
1
f_t=1
ft=1,细胞状态
c
t
=
θ
t
c_t=\theta _{t}
ct=θt,输入门
i
t
=
θ
t
i_t=\theta _{t}
it=θt,候选细胞状态
θ
t
=
θ
t
−
1
\theta _{t}=\theta _{t-1}
θt=θt−1。因此作者提出使用LSTM作为元学习器,把学习率
i
t
=
θ
t
i_t=\theta _{t}
it=θt,衰减率
f
t
f_t
ft,和初始化权重
c
0
c_0
c0作为用于学习的参数进行训练。
实现过程
网络结构
对于分类器,也就是任务学习器(learner)其网络结构并没有严格的限制,理论上任何采用梯度下降法学习的分类器都可以,作者为了与baseline Matching Netork作比较,因此也采用了相同的四级卷积神经网络结构,最后加一级全连接层用于计算各个类别的概率。对于LSTM网络,也就是元学习器(meta-learner)作者设计了一个两层网络,第一层就是一个普通的LSTM网络,第二层则是本文改进的LSTM元学习器。其中输入门
i
t
i_t
it和遗忘门
f
t
f_t
ft分别定义为以下形式
损失函数
本文采用了分类器正确分类的概率值的负对数的平均值作为损失。
训练策略
对于一个k-shot,N-class的分类任务,每个训练集 D t r a i n D_{train} Dtrain中包含N个类别,每个类别包含k个样本,每个测试集 D t e s t D_{test} Dtest中包含若干个类别和若干个样本,沿用Matching Network中的定义,将一个训练集 D t r a i n D_{train} Dtrain和一个测试集 D t e s t D_{test} Dtest合起来称为一个Eposide,而在整个元训练集 S m e t a − t r a i n S_{meta-train} Smeta−train中包含许多个Eposide。训练时,首先在一个Eposide内,在训练集 D t r a i n D_{train} Dtrain上反复迭代,更新任务学习器的权重参数。然后在测试集 D t e s t D_{test} Dtest上计算损失,用于更新元学习器的参数,并将更新后的参数用于下一个Eposide中任务学习器的训练。在整个元训练集 S m e t a − t r a i n S_{meta-train} Smeta−train反复迭代,得到最终的任务学习器参数和元学习器参数。
网络细节
- 参数共享。因为元学习器需要更新整个神经网络(任务学习器)中成千上万个参数,为了避免参数爆炸,作者采用了参数共享策略,即对于任务学习器中各个特征图上的每个像素点都采用相同的更新参数
(学习率,衰减率等)。又因为不同位置上的梯度和损失可能有很大的区别,作者引入了一种归一化的方式来避免这一差别带来的不良影响
- 梯度独立性假设。作者做了一个简化的假设条件,任务学习器的损失对于元学习器的梯度影响是不重要的可以忽略,这样可以避免求二阶导数,降低了计算复杂度。
- 元学习器LSTM的初始化。将输入门 i t i_t it的偏置 b I \textbf{b}_I bI初始化为一个较小的值,可以使得初始学习率较小,提高训练的稳定性。
- 批规范化。在元学习的过程中必须非常小心的使用批规范化,因为不想将在元测试的过程中的统计数据(均值和方差)泄漏到其他的Eposide中去。作者采用一种巧妙的策略避免了这一问题,在元测试过程中只收集保留当前Eposide的统计数据,当进入下一个Eposide时将数据清空。
创新点
- 采用元学习的方式,利用LSTM网络结构学习分类器的优化参数,使其符合小样本学习任务的需求
- 在网络设计过程中,采用了参数共享、梯度独立性假设的简化条件,降低了训练的复杂度
- 在参数初始化和批规范化设计中都采取了有针对性的解决方案,提高了训练的稳定性
算法评价
为解决小样本学习任务,作者提出一种基于元学习的解决方案,由试验结果可以看出,该方法在one-shot任务中与Matching Network的准确率相差无几,但在5-shot任务中则取得了比Matching Network更加优异的结果。该文章为小样本学习又指明一条新的道路,也是目前许多小样本学习采用的元学习方案。