优化器元网络推导
在梯度下降算法中,我们通过下面的公式来调整参数:
θ
t
=
θ
t
−
1
−
α
t
∇
θ
t
−
1
L
t
\boldsymbol{\theta}_{t} = \boldsymbol{\theta}_{t-1} - \alpha_{t} \nabla_{\boldsymbol{\theta}_{t-1}} \mathcal{L}_{t}
θt=θt−1−αt∇θt−1Lt
根据上一节长短时记忆网络的讨论,我们更新Cell记忆时的公式为:
C
t
=
f
t
⊗
C
t
−
1
+
i
t
⊗
C
~
t
\boldsymbol{C}_{t}=\boldsymbol{f}_{t} \otimes \boldsymbol{C}_{t-1} + \boldsymbol{i}_{t} \otimes \tilde{\boldsymbol{C}}_{t}
Ct=ft⊗Ct−1+it⊗C~t
我们可以做如下假设:
f
t
=
1
C
t
−
1
=
θ
t
−
1
i
t
=
α
t
C
t
~
=
∇
θ
t
−
1
L
t
\boldsymbol{f}_{t}=1 \\ \boldsymbol{C}_{t-1} = \boldsymbol{\theta}_{t-1} \\ \boldsymbol{i}_{t} = \alpha_{t} \\ \tilde{\boldsymbol{C}_{t}} = \nabla_{\boldsymbol{\theta}_{t-1}} \mathcal{L}_{t}
ft=1Ct−1=θt−1it=αtCt~=∇θt−1Lt
当我们做上述假设后,梯度下降算法就可以视为长短时记忆网络(LSTM)的Cell更新过程。如果我们把原来做图像分类的网络作为基础网络,用长短时记忆网络(LSTM)做优化器的元网络,LSTM网络中的Cell的状态值即为基础网络的参数,这样就构成了优化器元网络。
我们首先列出长短时记忆网络的公式:
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
+
b
f
)
i
t
=
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
+
b
i
)
C
~
t
=
t
a
n
h
(
W
C
⋅
[
h
t
−
1
,
x
t
]
+
b
C
)
C
t
=
f
t
⊗
C
t
−
1
+
i
t
⊗
C
~
t
o
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
o
)
h
t
=
o
t
⊗
t
a
n
h
(
C
t
)
\boldsymbol{f}_{t}=\sigma( W_{f} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{f} ) \\ \boldsymbol{i}_{t}=\sigma( W_{i} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{i} ) \\ \tilde{\boldsymbol{C}}_{t}=tanh( W_{C} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{C} ) \\ \boldsymbol{C}_{t}=\boldsymbol{f}_{t} \otimes \boldsymbol{C}_{t-1} + \boldsymbol{i}_{t} \otimes \tilde{\boldsymbol{C}}_{t} \\ \boldsymbol{o}_{t}=\sigma( W_{o} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{o} ) \\ \boldsymbol{h}_{t}=\boldsymbol{o}_{t} \otimes tanh(\boldsymbol{C}_{t})
ft=σ(Wf⋅[ht−1,xt]+bf)it=σ(Wi⋅[ht−1,xt]+bi)C~t=tanh(WC⋅[ht−1,xt]+bC)Ct=ft⊗Ct−1+it⊗C~tot=σ(Wo⋅[ht−1,xt]+bo)ht=ot⊗tanh(Ct)
我们首先来看遗忘门的应用,我们在训练深度学习网络中,最难处理的一种情况是我们进入了一个平坦的区域,这时梯度基本为零,对权值的调整将非常非常小,无法进行学习。如果我们利用元网络中的遗忘门,我们可以通过减小参数值,忘记一些之前的记忆来实现从这个平坦的区域内跳出来。此时遗忘门的输入值为:前一时刻的参数值、当前时刻基础网络的代价函数值、当前时刻基础网络的代价函数对网络参数的微分值、前一时刻遗忘门的输出,如下所示:
f
t
=
σ
(
W
f
⋅
[
θ
t
−
1
,
L
t
,
∇
θ
t
−
1
,
f
t
−
1
]
+
b
f
)
\boldsymbol{f}_{t}=\sigma( W_{f} \cdot [\boldsymbol{\theta}_{t-1}, \mathcal{L}_{t},\nabla_{\boldsymbol{\theta}_{t-1}}, \boldsymbol{f}_{t-1}] + \boldsymbol{b}_{f} )
ft=σ(Wf⋅[θt−1,Lt,∇θt−1,ft−1]+bf)
我们用输入门来控制需要更新哪些参数:
i
t
=
σ
(
W
i
⋅
[
θ
t
−
1
,
L
t
,
∇
θ
t
−
1
,
i
t
−
1
]
+
b
i
)
\boldsymbol{i}_{t}=\sigma( W_{i} \cdot [\boldsymbol{\theta}_{t-1}, \mathcal{L}_{t},\nabla_{\boldsymbol{\theta}_{t-1}}, \boldsymbol{i}_{t-1}] + \boldsymbol{b}_{i} )
it=σ(Wi⋅[θt−1,Lt,∇θt−1,it−1]+bi)
优化器学习算法
优化器元学习
利用随机值初始化元网络参数
ϕ
0
\boldsymbol{\phi}_{0}
ϕ0
For d=1…N iterations
\quad
从数据集
D
D
D中随机采样
D
t
r
a
i
n
D^{train}
Dtrain和
D
t
e
s
t
D^{test}
Dtest
\quad
将元网络中Cell初始状态
C
0
\boldsymbol{C}_{0}
C0赋给基础网络参数
θ
0
\boldsymbol{\theta}_{0}
θ0
\quad
For
t
=
1...
T
t=1...T
t=1...T iterations
\quad
\quad
从
D
t
r
a
i
n
D^{train}
Dtrain中随机抽样一个批次
X
t
,
Y
t
X_{t}, Y_{t}
Xt,Yt
\quad
\quad
在基础网络上计算代价函数值:
L
t
(
Y
t
∣
X
t
;
θ
t
)
\mathcal{L}_{t}(Y_{t} \vert X_{t}; \boldsymbol{\theta}_{t})
Lt(Yt∣Xt;θt)
\quad
\quad
遗忘门:
f
t
=
σ
(
W
f
⋅
[
θ
t
−
1
,
L
t
,
∇
θ
t
−
1
,
f
t
−
1
]
+
b
f
)
\boldsymbol{f}_{t}=\sigma( W_{f} \cdot [\boldsymbol{\theta}_{t-1}, \mathcal{L}_{t},\nabla_{\boldsymbol{\theta}_{t-1}}, \boldsymbol{f}_{t-1}] + \boldsymbol{b}_{f} )
ft=σ(Wf⋅[θt−1,Lt,∇θt−1,ft−1]+bf)
\quad
\quad
输入门:
i
t
=
σ
(
W
i
⋅
[
θ
t
−
1
,
L
t
,
∇
θ
t
−
1
,
i
t
−
1
]
+
b
i
)
\boldsymbol{i}_{t}=\sigma( W_{i} \cdot [\boldsymbol{\theta}_{t-1}, \mathcal{L}_{t},\nabla_{\boldsymbol{\theta}_{t-1}}, \boldsymbol{i}_{t-1}] + \boldsymbol{b}_{i} )
it=σ(Wi⋅[θt−1,Lt,∇θt−1,it−1]+bi)
\quad
\quad
输入信号预处理:
C
~
t
=
t
a
n
h
(
W
C
⋅
[
h
t
−
1
,
x
t
]
+
b
C
)
\tilde{\boldsymbol{C}}_{t}=tanh( W_{C} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{C} )
C~t=tanh(WC⋅[ht−1,xt]+bC)
\quad
\quad
更新Cell状态:
C
t
=
f
t
⊗
C
t
−
1
+
i
t
⊗
C
~
t
\boldsymbol{C}_{t}=\boldsymbol{f}_{t} \otimes \boldsymbol{C}_{t-1} + \boldsymbol{i}_{t} \otimes \tilde{\boldsymbol{C}}_{t}
Ct=ft⊗Ct−1+it⊗C~t
\quad
\quad
输出门:
o
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
o
)
\boldsymbol{o}_{t}=\sigma( W_{o} \cdot [\boldsymbol{h}_{t-1}, \boldsymbol{x}_{t}] + \boldsymbol{b}_{o} )
ot=σ(Wo⋅[ht−1,xt]+bo)
\quad
\quad
隐藏层状态:
h
t
=
o
t
⊗
t
a
n
h
(
C
t
)
\boldsymbol{h}_{t}=\boldsymbol{o}_{t} \otimes tanh(\boldsymbol{C}_{t})
ht=ot⊗tanh(Ct)
\quad
\quad
将元网络的Cell状态
C
t
\boldsymbol{C}_{t}
Ct赋引基础网络参数
θ
∣
t
\boldsymbol{\theta|_{t}}
θ∣t
\quad
EndFor
\quad
从测试集
D
t
e
s
t
D^{test}
Dtest中抽样出一个样本
X
,
Y
X, Y
X,Y
\quad
计算基础网络代价函数:
L
t
e
s
t
=
L
θ
t
(
Y
∣
X
;
θ
t
)
\mathcal{L}^{test}=\mathcal{L}_{\boldsymbol{\theta}_{t}}(Y \vert X; \boldsymbol{\theta}_{t})
Ltest=Lθt(Y∣X;θt)
\quad
利用
∇
L
t
e
s
t
θ
t
−
1
\nabla{\mathcal{L}^{test}}_{\boldsymbol{\theta}_{t-1}}
∇Ltestθt−1更新元网络参数
ϕ
d
\phi_{d}
ϕd
EndFor
在本章中我们讲述了元学习的基本概念,同时以优化器元学习网络为例,详细讲解了元学习算法的基本数学原理。在下一章中,我们将讨论最先出现同时也是使用最广泛的一种元学习网络Siamese网络,并通过TensorFlow2.0来实现对MNIST手写数字数据集的处理。