论文:[Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](# Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks)
台大李宏毅老师的视频课程:Meata learning: MAML
1. Meta Learning
Few-shot leanring的方法可以分为三类:
- metric-based: Siamese Networks, Matching Networks, Prototypical Networks
- Model-based: Meta Networks, Memory-Augmented Neural Networks
- Optimized-based: MAML
Model-based方法针对few-shot learning问题特别设计了模型,而MAML允许使用任何模型,所以叫做model-agnostic。
Meta learning可以称为是一种"learn to learn"的学习方法。以往的机器学习任务都是教给模型学习一个任务,但在meta learning中,一次让模型处理好几种任务,让模型学习“学习这件事”,未来有新的任务后模型能够很快进行处理。
Machine learning: machine learning的目的是从训练数据中到一个函数
f
f
f,即
f
(
x
)
⟶
y
f(x)\longrightarrow y
f(x)⟶y。
Meta learning: meta learning的目的是从一系列任务中找到一个函数
F
F
F,该函数
F
F
F具有找到新任务函数
f
f
f的能力,这一个函数f其实是目标任务模型的参数,即
F
(
x
)
⟶
f
F(x)\longrightarrow f
F(x)⟶f。
Meta learning的方法可以概括为:
- 定义一组学习算法 F F F
- 定义一个loss function, 用来评价哪一个学习算法比较好(这里的loss就不再是评价预测和标签是否一致,而是预测 f f f与期望 f ∗ f^{*} f∗是否一致)
- 使用梯度下降找到最好的学习算法
2.How MAML works
2.1 Define algorithm F F F
一般在有监督学习中,我们假设模型参数为
ϕ
∈
Φ
\phi \in \Phi
ϕ∈Φ,目标函数为:
min
ϕ
L
τ
(
ϕ
)
\underset{\phi}{\text{min}} \, \mathcal{L}_\tau (\phi)
ϕminLτ(ϕ)
梯度更新的表达式为:
ϕ
←
ϕ
−
α
∇
ϕ
L
τ
(
ϕ
)
,
\phi \leftarrow \phi - \alpha \nabla_\phi \mathcal{L}_\tau (\phi) ,
ϕ←ϕ−α∇ϕLτ(ϕ),
但是使用这样的参数更新方法来处理few-shot问题会导致over-fitting,因为训练样本数量太少而模型参数太多。那么我们是否可以做一个假设,用其他相似的任务对模型进行预训练,然后再将预训练模型应用于目标任务。也就是我们想要从大量的任务中训练一个具有较好泛化性的模型,该模型的参数可以快速适应到其它任务,这就是meta learning。
2.2 Define loss function
在传统的机器学习算法中,我们需要一个loss function来评价预测结果好坏。同样在meta learning中,我们也需要一个loss function来评价初始化参数的好坏。但是模型参数那么多,如何评价呢?我们可以用相同task的测试集的数据,如果测试集的数据得到的loss小,就表明该初始化参数比较好,反之亦然。
我们用task 1来训练学习算法(learning algorithm) F F F,会得到一个函数 f 1 f_1 f1。为了评价 f 1 f_1 f1好不好,我们可以将task 1的测试数据输入到模型中,得到测试结果的loss值 l 1 l_1 l1。模型在一个任务上表现好还不够,我们希望模型在很多个任务上都有很好的表现,因此,我们还要测试 F F F在task2,task3,task4……的表现。相应的,我们得到了不同任务的 l 2 l_2 l2, l 3 l_3 l3, l 4 l_4 l4……
所有的任务测试完成之后,我们将所有的loss加起来来评估
F
F
F的好坏:
L
(
F
)
=
∑
n
=
1
N
l
n
L(F)=\sum_{n=1}^{N}l_n
L(F)=n=1∑Nln
其中,
N
N
N代表task的数量,
l
n
l_n
ln是第
n
n
n个任务的test loss。
Machine Learning:
- training data
- test data
Meta Learning:
- training tasks
- training data
- test data
- test tasks
- training data
- test data
不过在few-shot learning中,我们通常把training data叫做support set,test data叫做query set。为了与传统的machine learning算法中的training data和test data进行区分,我后面都有叫做support set和query set。
在meta learning中我们不仅要寻找对所有任务都最优的初始化参数,当遇到新任务的时候也要可以微调模型适应新的任务(we will not simply use the data from other tasks to find parameters that are optimal for all tasks, but keep the option to fine-tune our model)。因此优化目标就可以写作:
min
θ
E
τ
[
L
τ
(
F
τ
(
θ
)
)
]
,
\underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (F_\tau(\theta)) ] ,
θminEτ[Lτ(Fτ(θ))],
其中,
F
τ
:
Φ
→
Φ
F_\tau: \Phi \rightarrow \Phi
Fτ:Φ→Φ是一个将
θ
\theta
θ映射到新的参数向量
F
τ
(
θ
)
F_\tau(\theta)
Fτ(θ)的一个优化算法,并且
F
τ
F_\tau
Fτ可以更具梯度下降算法更新。
θ
\theta
θ是由一些列任务学习得到的,它可以被认为是优化器
F
τ
F_\tau
Fτ的初始化参数,因此是目标任务的元参数(meta-parameter),优化元参数的过程叫做元学习(meta learning)。如果我们能找到一个最优的元参数
θ
\theta
θ,我们就可以使用很少的数据fine-tune任何任务而不会过拟合。
更简单一点的表示可以记作(来自李宏毅老师):
- 定义评价
F
F
F的损失函数:
L ( F ) = ∑ n = 1 N l n L(F)=\sum_{n=1}^{N}l_n L(F)=n=1∑Nln - 找到最优的
F
∗
F^*
F∗:
F ∗ = a r g min F L ( F ) F^*=\mathrm{arg} \min_F L(F) F∗=argFminL(F)
将找到的 F ∗ F^* F∗应用到测试任务中,将测试任务的suppor set输入到 F ∗ F^* F∗,模型会找到一个 f ∗ f^* f∗,将query set输入到 f ∗ f^* f∗中进行测试,可以得到query set的loss,这个loss就是meta learning训练完成后模型的好坏。
2. Model-Agnostic Meta-Learning
Few-shot learning的目标函数:
min
θ
E
τ
[
L
τ
(
F
τ
n
(
θ
)
)
]
,
n
>
0
\underset{\theta}{\text{min}} \, \mathbb{E}_\tau [ \mathcal{L}_\tau (F^{n}_\tau(\theta)) ] , n>0
θminEτ[Lτ(Fτn(θ))],n>0
我们令
τ
\tau
τ为训练任务,
τ
\tau
τ服从分布
τ
i
∼
p
(
τ
)
\tau_i \sim p(\tau)
τi∼p(τ)。任务
τ
\tau
τ是从数据集中随即选取的任务,因此它是一个随即变量。
MAML算法的流程如下:
- 首先初始化参数 θ \theta θ
- 随机选择一部分任务 τ i ∼ p ( τ ) \tau_i \sim p(\tau) τi∼p(τ)
- 对每一个任务使用梯度更新方法 θ i ′ = θ − α ∇ θ L τ i f ( θ ) \theta _{i}^{'}=\theta -\alpha \nabla _{\theta }\mathcal{L} _{\tau _i}f(\theta ) θi′=θ−α∇θLτif(θ),计算得到该任务下的最优的 θ i ′ \theta _{i}^{'} θi′作为初始化参数
- 现在我们已经得到了每个任务的初始化参数 θ i ′ \theta _{i}^{'} θi′,我们要评价学习到的初始化参数 θ i ′ \theta _{i}^{'} θi′的好坏,需要借助test data。为了评估training data或者说是(support set)训练得到的初始化参数 θ ′ \theta ^{'} θ′的好坏,将每个任务的test data(或者说是query set)输入到训练后的模型 f θ ′ f_{\theta^{'}} fθ′中,计算损失函数 L τ i ( f θ i ′ ) \mathcal{L} _{\tau _i}(f_{\theta ^{'}_{i}}) Lτi(fθi′)。因此,优化目标就变成了: min θ ∑ τ i ∼ p ( τ ) L τ i ( f θ i ′ ) \min_{\theta }\sum_{\tau _i\sim p(\tau)}\mathcal{L} _{\tau _i}(f_{\theta ^{'}_{i}}) θminτi∼p(τ)∑Lτi(fθi′)根据步骤3中的公式我们可以得到可以看到: min θ ∑ τ i ∼ p ( τ ) L τ i ( f θ i ′ ) = ∑ τ i ∼ p ( τ ) L τ i ( f θ − α ∇ θ L τ i f ( θ ) ) \min_{\theta }\sum_{\tau _i\sim p(\tau)}\mathcal{L} _{\tau _i}(f_{\theta ^{'}_{i}})=\sum_{\tau _i\sim p(\tau)}\mathcal{L}_{\tau _i}(f_{\theta -\alpha \nabla _{\theta }\mathcal{L} _{\tau _i}f(\theta )}) θminτi∼p(τ)∑Lτi(fθi′)=τi∼p(τ)∑Lτi(fθ−α∇θLτif(θ))这是一个关于 θ \theta θ求两次梯度的目标函数。则meta-gradient update就可以写成: θ ← θ − β ∇ θ ∑ τ i ∼ p ( τ ) L τ i ( f θ i ′ ) \theta \gets \theta -\beta \nabla _{\theta }\sum_{\tau _i\sim p(\tau )}\mathcal{L} _{\tau _i}(f_{\theta _{i}^{'}}) θ←θ−β∇θτi∼p(τ)∑Lτi(fθi′)
算法的为代码如下图所示。
值得注意的是:meta-gradient update更新的是初始值。初始值
θ
\theta
θ会影响训练出来的
θ
′
\theta ^{'}
θ′,求导记作:
∂
L
(
f
θ
′
)
∂
θ
i
=
∑
j
∂
L
(
f
θ
′
)
∂
θ
j
′
∂
θ
j
′
∂
θ
i
\frac{\partial \mathcal{L} (f_{\theta ^{'}})}{\partial \theta _i} =\sum_{j}^{} \frac{\partial \mathcal{L} (f_{\theta ^{'}})}{\partial \theta ^{'}_{j}} \frac{\partial \theta ^{'}_{j}}{\partial \theta _{i}}
∂θi∂L(fθ′)=j∑∂θj′∂L(fθ′)∂θi∂θj′
Meta-gradient带来的一个问题是增加了一次额外的求导,作者在论文中提到使用了一阶近似来表示黑塞矩阵,结果和求两次导数差不多。
3. MAML的物理意义
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jzUEgr85-1679969619219)(figs/maml.png)]
假设初始化参数为
θ
0
\theta ^{0}
θ0,的一个task的training set更新后变成
θ
^
m
\hat{\theta} ^{m}
θ^m,用test set计算第二次梯度会得到另外一个向量,用这个向量的方向来更新
θ
0
\theta ^{0}
θ0得到
θ
1
\theta {1}
θ1。
而模型预训练与它不同,它一次只更新一个梯度,所以计算出来哪个方向就其哪个方向。