小样本学习&元学习经典论文整理||持续更新
核心思想
本文是在MAML的思路上进一步改进,提出一种基于参数优化的小样本学习算法Reptile。首先我们一起回忆一下MAML是如何进行元学习的,在之前的文章中,我们有提到MAML的训练可以分为两个层次:内层优化和外层优化,内层优化就与普通的训练一样,假设网络初始参数为
θ
0
\theta^0
θ0,在数据集
A
A
A上采用SGD的方式进行训练后得到参数
θ
′
\theta'
θ′。如果是普通的训练,那么就会接着采样一个数据集
B
B
B,然后以
θ
′
\theta'
θ′作为初始参数继续训练了。MAML同样采集一个数据集
B
B
B,然后用在数据集
A
A
A上训练得到的模型
f
θ
′
f_{\theta'}
fθ′处理数据集
B
B
B上的样本,并计算损失。不同的是,MAML利用该损失计算得到的梯度对
θ
0
\theta^0
θ0进行更新。
θ
1
=
θ
0
−
ε
▽
θ
∑
T
i
∈
B
L
T
i
(
f
θ
′
)
\theta^1=\theta^0-\varepsilon\triangledown _{\theta}\sum_{\mathcal{T}_i\in B}\mathcal{L}_{\mathcal{T}_i}(f_{\theta'})
θ1=θ0−ε▽θTi∈B∑LTi(fθ′) 也就是说MAML的目标是训练得到一个好的初始化参数
θ
\theta
θ,使其能够在处理其他任务时很快的收敛到一个较好的结果。在梯度计算过程中会涉及到二阶导数计算,MAML利用一阶导数近似方法(FOMAML)进行处理,发现结果相差并不大,但计算量会减少很多。回到本文,本文提出的算法就是在FOMAML,进一步简化参数更新的方式,甚至连损失梯度都不需要计算了,直接利用
θ
0
−
θ
′
\theta^0-\theta'
θ0−θ′(这里的符号使用的与原文不同,但表达含义相同)作为梯度对参数进行更新,即
θ
1
=
θ
0
−
ε
(
θ
0
−
θ
′
)
\theta^1=\theta^0-\varepsilon(\theta^0-\theta')
θ1=θ0−ε(θ0−θ′) 可能有人会觉得这样做,不是相当于退化成普通的训练过程了吗,因为
θ
′
\theta'
θ′还是利用SGD方式得到的,然后让
θ
0
\theta^0
θ0沿着
θ
0
−
θ
′
\theta^0-\theta'
θ0−θ′的方向更新,就得到
θ
1
\theta^1
θ1。如果说在训练数据集
A
A
A中只有一个训练样本,或者说只经过一个batch的训练,那么本文的算法的确会退化为普通的SGD训练,但如果每个数据集都进行不止一个Batch的训练,二者就不相同了。
如上图所示,本文在更新参数
θ
0
\theta^0
θ0时,会在数据集
A
A
A上做多个Batch的计算得到最终的参数
θ
′
\theta'
θ′(也就是图中的
θ
^
m
\hat{\theta}^m
θ^m),然后再回到
θ
0
\theta^0
θ0处,计算
θ
0
−
θ
′
\theta^0-\theta'
θ0−θ′,并更新参数得到
θ
1
\theta^1
θ1。而普通的SGD就不会回到
θ
0
\theta^0
θ0处了,而是继续以
θ
′
\theta'
θ′作为初始值进行更新了。MAML,本文提出的算法(Reptile)和普通的SGD(Pre-train)方法的比较如下图所示
假设训练集中包含两个Batch,则对于Pre-train模型会沿着
g
1
g_1
g1方向更新,MAML会沿着
g
2
g_2
g2的方向更新,Reptile则会沿着
g
1
+
g
2
g_1+g_2
g1+g2的方向更新,看起来像是传统方法和MAML结合的产物。作者还从数学的角度上,证明了Reptile与FOMAML在本质上是相同的,但是Reptile的计算效率和内存占用明显要优于FOMAML,原谅我才疏学浅,并没有看懂证明过程,有兴趣的读者请阅读原文。
创新点
- 本文在MAML的基础上提出一种更为简单的进行参数初始化的元学习算法,并从数学上证明了其与一阶近似MAML的等价性
算法评价
本文是一篇充满学术气息的文章,包含大量的数学证明,但其提出的算法却是非常简单的——直接用向量差作为梯度,而且通过数学和实验两种方式证明了,本文提出的算法Reptile性能与MAML非常接近。但是作者同样也发现,本文的算法虽然在分类任务中取得了较好的结果,但是在强化学习中,却没有MAML表现的那样好,这也是一个值得探究的方向。这里推荐台湾大学李宏毅教授的讲解视频,他生动形象地介绍了MAML和Reptile算法,给了我很大的帮助和启迪,B站视频链接。
如果大家对于深度学习与计算机视觉领域感兴趣,希望获得更多的知识分享与最新的论文解读,欢迎关注我的个人公众号“深视”。