Meta-learning with implicit gradients--nips19
Paper Idea
A major challenge with the original MAML algorithm is that the outer loop (the meta-update) must differentiate through the inner loop (the gradient-based adaptation), which in general requires storing the inner trajectory and computing higher-order derivatives. The core idea of this paper is to use implicit differentiation: computing the meta-gradient requires only the solution of the inner-loop optimization, not the full trajectory of the inner optimizer.
Benefits: (1) the meta-gradient computation (outer loop) is decoupled from the choice of inner optimizer, so any inner optimizer can be used; (2) taking many inner gradient steps no longer causes vanishing gradients or memory blow-up.
As the figure above shows, MAML computes the meta-gradient by differentiating through the inner-loop optimization path, while first-order MAML simply approximates $\frac{d\phi_i}{d\theta}$ by the identity $I$. iMAML instead uses an estimate of the local curvature to derive an exact analytic expression for the meta-gradient (expressed in terms of the inner-loop solution rather than derivatives along the path to that solution), so the optimization path never needs to be differentiated.
The advantages: the optimization path need not be stored or differentiated, so many gradient steps can be taken in the inner loop; and the method is agnostic to the choice of inner optimizer, requiring only an approximate solution of the inner problem. This permits higher-order methods and even non-differentiable inner optimizers.
Few-shot case formula
$$\overbrace{\boldsymbol{\theta}_{\mathrm{ML}}^{*}:=\underset{\boldsymbol{\theta} \in \Theta}{\operatorname{argmin}}\, F(\boldsymbol{\theta})}^{\text{outer-level}}, \text{ where } F(\boldsymbol{\theta})=\frac{1}{M} \sum_{i=1}^{M} \mathcal{L}\Big(\overbrace{\mathcal{A}lg\left(\boldsymbol{\theta}, \mathcal{D}_{i}^{\mathrm{tr}}\right)}^{\text{inner-level}}, \mathcal{D}_{i}^{\text{test}}\Big)$$
In this formula, $\mathcal{A}lg$ denotes the inner-loop algorithm, which outputs the task-adapted parameters. To prevent overfitting, a regularization term can be added to the inner-loop objective:
$$\mathcal{A}lg^\star\left(\boldsymbol{\theta}, \mathcal{D}_{i}^{\mathrm{tr}}\right)=\underset{\phi'\in\Phi}{\operatorname{argmin}}\ \mathcal{L}(\phi',\mathcal{D}^{\mathrm{tr}}_{i})+\frac{\lambda}{2}\|\phi'-\theta\|^2$$
Here $\theta$ is the meta-parameter we are solving for (i.e., the model initialization); it is treated as a constant inside the inner loop and updated by gradient steps in the outer loop, while the actual variable of the inner loop is the adapted parameter $\phi'$. The $\star$ indicates an exact solution; in practice an iterative gradient method returns only an approximate minimizer. A minimal sketch of this regularized inner loop is given below.
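To make the inner level concrete, here is a minimal sketch of this regularized adaptation step, assuming the parameters form a single flat tensor; the names (`inner_adapt`, `loss_fn`, `data_tr`) and the choice of plain gradient descent are illustrative, not the paper's reference implementation:

```python
import torch

def inner_adapt(theta, loss_fn, data_tr, lam=1.0, steps=25, lr=1e-2):
    """Approximately solve Alg*(theta, D_tr) = argmin_phi' L(phi', D_tr)
    + (lam/2) * ||phi' - theta||^2 with plain gradient descent.
    Any other inner optimizer could be substituted."""
    phi = theta.clone().detach().requires_grad_(True)  # adapt from the meta-initialization
    for _ in range(steps):
        reg = 0.5 * lam * (phi - theta.detach()).pow(2).sum()
        loss = loss_fn(phi, data_tr) + reg
        grad, = torch.autograd.grad(loss, phi)
        with torch.no_grad():
            phi -= lr * grad                           # one inner gradient step
    return phi.detach()                                # phi_i, an estimate of Alg*_i(theta)
```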
The two-level (bi-level) optimization problem can then be rewritten as:
$$\begin{array}{l}{\boldsymbol{\theta}_{\mathrm{ML}}^{*}:=\underset{\boldsymbol{\theta} \in \Theta}{\operatorname{argmin}}\, F(\boldsymbol{\theta}), \text { where } F(\boldsymbol{\theta})=\frac{1}{M} \sum_{i=1}^{M} \mathcal{L}_{i}\left(\mathcal{A}lg_{i}^{\star}(\boldsymbol{\theta})\right), \text { and }} \\ {\mathcal{A}lg_{i}^{\star}(\boldsymbol{\theta}):=\underset{\boldsymbol{\phi}^{\prime} \in \Phi}{\operatorname{argmin}}\, G_{i}\left(\boldsymbol{\phi}^{\prime}, \boldsymbol{\theta}\right), \text { where } G_{i}\left(\boldsymbol{\phi}^{\prime}, \boldsymbol{\theta}\right)=\hat{\mathcal{L}}_{i}\left(\boldsymbol{\phi}^{\prime}\right)+\frac{\lambda}{2}\left\|\boldsymbol{\phi}^{\prime}-\boldsymbol{\theta}\right\|^{2}}\end{array}$$
where
$$\mathcal{L}_{i}(\phi):=\mathcal{L}\left(\phi, \mathcal{D}_{i}^{\text{test}}\right), \quad \hat{\mathcal{L}}_{i}(\phi):=\mathcal{L}\left(\phi, \mathcal{D}_{i}^{\text{tr}}\right), \quad \mathcal{A}lg_{i}(\boldsymbol{\theta}):=\mathcal{A}lg\left(\boldsymbol{\theta}, \mathcal{D}_{i}^{\text{tr}}\right)$$
Writing $d$ for total derivatives and $\nabla$ for partial derivatives, the chain rule gives the meta-gradient:
$$d_{\boldsymbol{\theta}}\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))=\frac{d\mathcal{A}lg_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}\,\nabla_\phi\mathcal{L}_i(\phi)\Big|_{\phi=\mathcal{A}lg_{i}(\boldsymbol{\theta})}=\frac{d\mathcal{A}lg_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}\,\nabla_\phi\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))$$
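Putting the two levels together, a bi-level training skeleton might look as follows; everything here is a hypothetical sketch, and `meta_grad_fn` stands for whatever estimator of $d_{\boldsymbol{\theta}}\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))$ is plugged in, such as the implicit estimator derived in the next section:

```python
import torch

def meta_train(theta, task_batches, loss_fn, inner_adapt, meta_grad_fn,
               lam=1.0, meta_lr=1e-3):
    """Bi-level training skeleton: per-task inner adaptation, then a meta-update.
    `task_batches` yields lists of (data_tr, data_test) pairs; `meta_grad_fn`
    estimates d_theta L_i(Alg_i(theta)). All names are illustrative."""
    for batch in task_batches:
        meta_grad = torch.zeros_like(theta)
        for data_tr, data_test in batch:
            phi_i = inner_adapt(theta, loss_fn, data_tr, lam)      # inner level
            meta_grad += meta_grad_fn(phi_i, theta, loss_fn,
                                      data_tr, data_test, lam)     # meta-gradient estimate
        theta = theta - meta_lr * meta_grad / len(batch)           # outer (meta) step
    return theta
```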
Implicit MAML Algorithm
In the expression above, $\nabla_\phi\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))=\nabla_\phi\mathcal{L}_i(\phi)\big|_{\phi=\mathcal{A}lg_{i}(\boldsymbol{\theta})}$ is easy to compute once $\mathcal{A}lg^\star_{i}(\boldsymbol{\theta})$ has been obtained (by gradient descent or any other optimizer). The Jacobian $\frac{d\mathcal{A}lg_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}$ is the hard part: differentiating through the updates directly involves higher-order derivatives and requires recording the entire update trajectory. Instead, define the result of the inner (adaptation) loop, $\phi_i = \mathcal{A}lg^\star_{i}(\boldsymbol{\theta})$, implicitly as the solution of the inner optimization problem. Its Jacobian can then be computed without reference to the optimization path (Lemma 1):
$$\frac{d\mathcal{A}lg^\star_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}=\left(\boldsymbol{I}+\frac{1}{\lambda}\nabla^2_\phi\hat{\mathcal{L}}_i(\phi_i)\right)^{-1}$$
Proof: $\phi_i = \mathcal{A}lg^\star_{i}(\boldsymbol{\theta})$ minimizes $G_{i}\left(\boldsymbol{\phi}^{\prime}, \boldsymbol{\theta}\right)=\hat{\mathcal{L}}_{i}\left(\boldsymbol{\phi}^{\prime}\right)+\frac{\lambda}{2}\left\|\boldsymbol{\phi}^{\prime}-\boldsymbol{\theta}\right\|^{2}$, so it satisfies the first-order necessary condition, i.e., the gradient vanishes:
$$\nabla_{\boldsymbol{\phi}^{\prime}} G\left(\boldsymbol{\phi}^{\prime}, \boldsymbol{\theta}\right)\Big|_{\boldsymbol{\phi}^{\prime}=\boldsymbol{\phi}_i}=0 \;\Longrightarrow\; \nabla \hat{\mathcal{L}}(\boldsymbol{\phi}_i)+\lambda(\boldsymbol{\phi}_i-\boldsymbol{\theta})=0 \;\Longrightarrow\; \boldsymbol{\phi}_i=\boldsymbol{\theta}-\frac{1}{\lambda} \nabla \hat{\mathcal{L}}(\boldsymbol{\phi}_i)$$
This is a standard implicit equation. When the derivative exists, differentiating both sides with respect to $\boldsymbol{\theta}$ (note that $\phi_i$ itself depends on $\theta$, so the chain rule applies to the $\nabla\hat{\mathcal{L}}(\phi_i)$ term) gives:
$$\frac{d \boldsymbol{\phi}_i}{d \boldsymbol{\theta}}=I-\frac{1}{\lambda} \nabla^{2} \hat{\mathcal{L}}(\boldsymbol{\phi}_i)\, \frac{d \boldsymbol{\phi}_i}{d \boldsymbol{\theta}} \;\Longrightarrow\; \left(I+\frac{1}{\lambda} \nabla^{2} \hat{\mathcal{L}}(\boldsymbol{\phi}_i)\right) \frac{d \boldsymbol{\phi}_i}{d \boldsymbol{\theta}}=I$$
Inverting the bracketed matrix yields exactly the expression in Lemma 1. $\blacksquare$
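Lemma 1 can be sanity-checked numerically on a quadratic training loss, where the inner problem has a closed-form solution; the snippet below is such a check (all names illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, lam = 4, 2.0
B = rng.standard_normal((n, n))
A = B @ B.T      # Hessian of the quadratic training loss L_hat(phi) = 0.5 * phi^T A phi

# Stationarity of G_i:  A phi + lam (phi - theta) = 0  =>  phi(theta) = lam (A + lam I)^{-1} theta,
# so the Jacobian d phi / d theta is lam (A + lam I)^{-1} in closed form.
J_closed_form = lam * np.linalg.inv(A + lam * np.eye(n))

# Lemma 1 predicts d phi / d theta = (I + (1/lam) * Hessian)^{-1}.
J_lemma = np.linalg.inv(np.eye(n) + A / lam)

assert np.allclose(J_closed_form, J_lemma)   # the two expressions coincide
```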
Practical Algorithm
Computing $\frac{d \boldsymbol{\phi}_i}{d \boldsymbol{\theta}}=\frac{d\mathcal{A}lg^\star_{i}(\boldsymbol{\theta})}{d\boldsymbol{\theta}}$ from the formula above poses two difficulties. First, $\mathcal{A}lg^\star_{i}$ is the exact solution, whereas inner-loop optimization typically yields only an approximation. Second, the computation involves a matrix inverse and second derivatives, which are prohibitive for deep neural networks. The paper therefore approximates the meta-gradient, with the following guarantee as its core requirement:
$$\left\|\boldsymbol{g}_{i}-\left(I+\frac{1}{\lambda} \nabla_{\boldsymbol{\phi}}^{2} \hat{\mathcal{L}}_{i}\left(\boldsymbol{\phi}_{i}\right)\right)^{-1} \nabla_{\boldsymbol{\phi}} \mathcal{L}_{i}\left(\boldsymbol{\phi}_{i}\right)\right\| \leq \delta^{\prime}$$
Here $\boldsymbol{g}_i$ is the estimate of the meta-gradient $d_{\boldsymbol{\theta}}\mathcal{L}_i(\mathcal{A}lg_{i}(\boldsymbol{\theta}))$, and $\boldsymbol{\phi}_i$ is an approximation of the optimum $\mathcal{A}lg^\star_{i}$ obtained by iterative gradient-based optimization. Computing $\boldsymbol{g}_i$ can then be cast as a quadratic optimization problem:
$$\min_{\boldsymbol{w}}\ \frac{1}{2}\boldsymbol{w}^{\top}\left(\boldsymbol{I}+\frac{1}{\lambda} \nabla_{\boldsymbol{\phi}}^{2} \hat{\mathcal{L}}_{i}\left(\boldsymbol{\phi}_{i}\right)\right) \boldsymbol{w}-\boldsymbol{w}^{\top} \nabla_{\boldsymbol{\phi}} \mathcal{L}_{i}\left(\boldsymbol{\phi}_{i}\right)$$
This can be solved quickly with the conjugate gradient (CG) method. The procedure only ever needs the Hessian-vector product $\nabla^2\hat{\mathcal{L}}_i(\boldsymbol{\phi}_i)\boldsymbol{v}$, where $\boldsymbol{v}$ is the current CG direction vector, so the Hessian never has to be formed or inverted explicitly.
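Below is a hedged sketch of this CG-based estimator for a flat parameter vector, using autodiff Hessian-vector products; the function name and signature are assumptions (it would play the role of `meta_grad_fn` in the earlier skeleton), not the authors' released code:

```python
import torch

def implicit_meta_grad(phi, theta, loss_fn, data_tr, data_test, lam, cg_steps=10):
    """Approximate g_i = (I + (1/lam) H)^{-1} grad_phi L_i(phi_i) with conjugate
    gradient, where H is the Hessian of the training loss at phi_i. Only
    Hessian-vector products are evaluated, never the full Hessian.
    Note: theta itself is unused; the implicit Jacobian depends only on phi_i."""
    phi = phi.detach().requires_grad_(True)
    b, = torch.autograd.grad(loss_fn(phi, data_test), phi)        # grad_phi L_i(phi_i)
    tr_grad, = torch.autograd.grad(loss_fn(phi, data_tr), phi,
                                   create_graph=True)             # kept for double backward

    def Av(v):
        # Matrix-vector product (I + H/lam) v via a Hessian-vector product.
        hv, = torch.autograd.grad(tr_grad, phi, grad_outputs=v, retain_graph=True)
        return v + hv / lam

    # Plain conjugate gradient on (I + H/lam) g = b, starting from g = 0.
    g = torch.zeros_like(b)
    r = b.clone()                  # residual b - A g
    p = r.clone()                  # search direction
    rs = r.dot(r)
    for _ in range(cg_steps):
        Ap = Av(p)
        alpha = rs / p.dot(Ap)
        g = g + alpha * p
        r = r - alpha * Ap
        rs_new = r.dot(r)
        p = r + (rs_new / rs) * p
        rs = rs_new
    return g.detach()              # g_i: the implicit meta-gradient estimate
```

Each CG iteration costs one Hessian-vector product, i.e., roughly one extra backward pass, so the memory footprint stays constant no matter how many steps the inner loop took.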