论文地址:proceedings.mlr.press/v70/finn17a/finn17a.pdf
5.1 简介
Model-Agnostic:可适用于任何梯度下降
的模型,可用于不同的学习任务
(如分类、回归、策略梯度RL)。
Meta-Learning:在大量的学习任务上训练模型,从而让模型仅用小数量的训练样本
就可以学习新任务
(加速fine-tune)。不同的任务有不同的模型
。
需要考虑将先前的经验与少量新信息融合,同时避免过拟合。
方法的核心是训练模型的初始参数
,从而使模型仅在少量新任务样本上通过几步梯度更新就达到最好性能。
5.2 方法
首先通过如下算法1初始化网络权重,然后在新任务上微调训练。
对上述算法1的解读:
MAML的目的是:学习网络的初始化权重,从而使网络只在新任务上训练一步或几步就能达到很好的效果
。
和深度学习的核心一样,算法1中的训练任务
和初始化权重后的微调的测试任务
都是样本,和一般深度学习中的训练集
和测试集
的目的和概念一样。
对上述描述进行公式化:假设网络的初始化权重为
θ
\theta
θ,网络在不同新任务
τ
i
\tau_i
τi的训练集上经过一步梯度更新后的权重为
θ
i
′
\theta'_i
θi′,使用更新后的权重
θ
i
′
\theta'_i
θi′在新任务
τ
i
\tau_i
τi的测试集上计算损失
L
i
(
θ
i
′
)
\mathcal{L}_i (\theta'_i)
Li(θi′),MAML的目的是使不同新任务
τ
i
\tau_i
τi上的损失之和最小,公式如下:
L
=
m
i
n
∑
τ
i
∼
p
(
τ
)
L
i
(
θ
i
′
)
L = min ~ \sum_{\tau_i \sim p(\tau)} \mathcal{L}_i (\theta'_i)
L=min τi∼p(τ)∑Li(θi′)
以上述为总损失函数,对网络权重
θ
\theta
θ进行梯度下降,如下:
θ
←
θ
−
β
∇
θ
∑
τ
i
∼
p
(
τ
)
L
i
(
θ
i
′
)
=
θ
−
β
∑
τ
i
∼
p
(
τ
)
∇
θ
L
i
(
θ
i
′
)
\theta \leftarrow \theta - \beta \nabla_{\theta} \sum_{\tau_i \sim p(\tau)} \mathcal{L}_i (\theta'_i) \\ ~~ = \theta - \beta \sum_{\tau_i \sim p(\tau)} \nabla_{\theta} \mathcal{L}_i (\theta'_i)
θ←θ−β∇θτi∼p(τ)∑Li(θi′) =θ−βτi∼p(τ)∑∇θLi(θi′)
计算
∇
θ
L
i
(
θ
i
′
)
\nabla_{\theta} \mathcal{L}_i (\theta'_i)
∇θLi(θi′):
借用李宏毅老师讲义中的公式, ϕ = θ \phi=\theta ϕ=θ, θ ^ = θ i ′ \hat{\theta}=\theta'_i θ^=θi′, ∇ θ L i ( θ i ′ ) = ∇ ϕ l ( θ ^ ) \nabla_{\theta} \mathcal{L}_i (\theta'_i) = \nabla_{\phi} l(\hat\theta) ∇θLi(θi′)=∇ϕl(θ^), ∇ ϕ l ( θ ^ ) \nabla_{\phi} l(\hat\theta) ∇ϕl(θ^)可以分解为如下公式,
其中,
θ
^
\hat{\theta}
θ^由
ϕ
\phi
ϕ计算得到,如下:
通过如下公式计算
∇
ϕ
l
(
θ
^
)
\nabla_{\phi} l(\hat\theta)
∇ϕl(θ^)中的每一项导数:
计算二阶导数非常耗时,所以MAML论文中提出使用一阶导数近似方法,即假设二阶导数都为0,对公式简化如下:
简化后,
∇
ϕ
l
(
θ
^
)
→
∇
θ
^
l
(
θ
^
)
\nabla_{\phi} l(\hat\theta) \rightarrow \nabla_{\hat\theta} l(\hat\theta)
∇ϕl(θ^)→∇θ^l(θ^),原梯度下降公式转化为:
θ
←
θ
−
β
∑
τ
i
∼
p
(
τ
)
∇
θ
i
′
L
i
(
θ
i
′
)
\theta \leftarrow \theta - \beta \sum_{\tau_i \sim p(\tau)} \nabla_{\theta'_i} \mathcal{L}_i (\theta'_i)
θ←θ−βτi∼p(τ)∑∇θi′Li(θi′)
即,直接对每个更新后的
θ
i
′
\theta'_i
θi′计算梯度,将梯度作用到更新前的
θ
\theta
θ上。
问题:
1、为什么循环随机采样多个任务进行学习?
答:构建足够多的不同任务,使网络得到充分训练,从而在面向新任务时只通过几步更新就能达到较好的效果。
2、为什么第一次计算梯度与第二次计算梯度使用相同任务下的不同样本,即support set和query set?
答:前者是训练集,用于计算得到 θ i ′ \theta'_i θi′,后者是测试集,用于计算损失。
3、相比于先在一大堆任务上预训练(每次只计算一次梯度),再在新任务上微调,优势是什么?
答:预训练的目的是使网络在所有任务上的性能达到最优,将这个最优模型用于新任务微调时,可能陷入局部最优值等问题;而MAML的目的是使模型在新任务上训练几步后的性能达到最优,考虑的是未来的最优值,因此不会在某些任务上达到最优,而在其他任务上陷入次优。
更多细节请参考:
https://zhuanlan.zhihu.com/p/57864886
https://www.bilibili.com/video/BV1w4411872t?p=7&vd_source=383540c0e1a6565a222833cc51962ed9