1.前记
续《LSTM反向传播详解Part1》后续……,关于本篇文章的主题就是得到“模型矩阵参数”,可能继续需要一点数学知识。
在这里推荐另外的一篇文章《机器学习深度学习中反向传播之偏导数链式法则》(本文以《链式法则》简称),自卖自夸一下:文章从微观的单变量元素"一五一十一板一眼的求"(你一定能从原文中找到这句话),到以向量、矩阵角度考虑阐述关于偏导数的链式法则。所以本文的公式将不再推导,而是直接从Part1中
∂
L
/
∂
h
t
,
∂
L
/
∂
c
t
\partial L/\partial h_t,\partial L/\partial c_t
∂L/∂ht,∂L/∂ct得到模型中参数
W
W
W矩阵以及偏置参数
b
b
b。
在写本文的时候已经将LSTM的训练用python实现了,后续会在Part3中结合本文与Part1内容讲解训练的实现。因为在写LSTM系列文章的时候就有点担心,这么多公式,要尽量保证正确还是挺难的,心里会没谱(即使通过代码实现了LSTM的训练,但也不能严格说明这一列文章确定没犯错,或许是凑巧得到了正确答案也不一定)。
2.反向传播part2
在《LSTM反向传播详解Part1》文中,我们得到了以下变量:
∂
L
/
∂
h
τ
,
∂
L
/
∂
c
τ
\partial L/\partial h_{\tau},\partial L/\partial c_{\tau}
∂L/∂hτ,∂L/∂cτ,
∂
L
/
∂
h
t
,
∂
L
/
∂
c
t
\partial L/\partial h_t,\partial L/\partial c_t
∂L/∂ht,∂L/∂ct,现在我们需要得到模型的参数
W
f
,
b
f
,
W
i
,
b
i
,
W
c
,
b
c
,
W
o
,
b
o
W_f,b_f,W_i,b_i,W_c,b_c,W_o,b_o
Wf,bf,Wi,bi,Wc,bc,Wo,bo?
正向模型参数方程:
以对
b
o
b_o
bo矩阵的偏导数为例,先求
o
t
o_{t}
ot的偏导数,根据《链式法则》中公式(4):
∂
L
/
∂
o
t
=
∂
L
/
∂
h
t
⊙
tanh
(
c
t
)
\partial L/\partial o_{t} = \partial L/\partial h_{t}\ \odot \tanh(c_{t})
∂L/∂ot=∂L/∂ht ⊙tanh(ct)
从而
∂
L
/
∂
b
o
=
∂
L
/
∂
h
t
⊙
tanh
(
c
t
)
⊙
o
t
⊙
(
1
−
o
t
)
\partial L/\partial b_{o} = \partial L/\partial h_{t}\ \odot \tanh(c_{t})\ \odot o_t \odot (1-o_t)
∂L/∂bo=∂L/∂ht ⊙tanh(ct) ⊙ot⊙(1−ot)
继续求对
W
o
W_{o}
Wo的偏导数(以
h
x
t
−
1
hx_{t-1}
hxt−1简称
[
h
t
−
1
;
x
t
−
1
]
这
个
列
向
量
[h_{t-1};x_{t-1}]这个列向量
[ht−1;xt−1]这个列向量),根据《链式法则》中公式(2):
∂
L
/
∂
W
o
=
∂
L
/
∂
b
o
⋅
h
x
t
−
1
T
=
(
∂
L
/
∂
h
t
⊙
tanh
(
c
t
)
⊙
o
t
⊙
(
1
−
o
t
)
)
⋅
h
x
t
−
1
T
\partial L/\partial W_{o} = \partial L/\partial b_{o}\cdot hx_{t-1}^T =(\partial L/\partial h_{t}\ \odot \tanh(c_{t})\ \odot o_t \odot (1-o_t))\cdot hx_{t-1}^T
∂L/∂Wo=∂L/∂bo⋅hxt−1T=(∂L/∂ht ⊙tanh(ct) ⊙ot⊙(1−ot))⋅hxt−1T
继续刷公式:
{
∂
L
∂
b
f
=
∂
L
∂
c
t
⊙
c
t
−
1
⊙
f
t
⊙
(
1
−
f
t
)
∂
L
∂
W
f
=
∂
L
∂
b
f
⋅
h
x
t
−
1
T
\left\{ \begin{aligned} \frac {\partial L}{\partial b_f} & =\frac {\partial L}{\partial c_t}\odot c_{t-1}\odot f_t \odot (1-f_t)\\ \frac {\partial L}{\partial W_f} & = \frac {\partial L}{\partial b_f} \cdot hx_{t-1}^T \end{aligned} \right.
⎩⎪⎪⎨⎪⎪⎧∂bf∂L∂Wf∂L=∂ct∂L⊙ct−1⊙ft⊙(1−ft)=∂bf∂L⋅hxt−1T
{
∂
L
∂
b
i
=
∂
L
∂
c
t
⊙
c
t
^
⊙
i
t
⊙
(
1
−
i
t
)
∂
L
∂
W
i
=
∂
L
∂
b
i
⋅
h
x
t
−
1
T
\left\{ \begin{aligned} \frac {\partial L}{\partial b_i} & =\frac {\partial L}{\partial c_t}\odot \hat{c_t}\odot i_t \odot (1-i_t)\\ \frac {\partial L}{\partial W_i} & = \frac {\partial L}{\partial b_i} \cdot hx_{t-1}^T \end{aligned} \right.
⎩⎪⎪⎨⎪⎪⎧∂bi∂L∂Wi∂L=∂ct∂L⊙ct^⊙it⊙(1−it)=∂bi∂L⋅hxt−1T
{
∂
L
∂
b
c
=
∂
L
∂
c
t
⊙
i
t
⊙
(
1
−
c
t
^
⊙
c
t
^
)
∂
L
∂
W
c
=
∂
L
∂
b
c
⋅
h
x
t
−
1
T
\left\{ \begin{aligned} \frac {\partial L}{\partial b_c} & =\frac {\partial L}{\partial c_t}\odot i_t \odot (1-\hat{c_t}\odot\hat{c_t})\\ \frac {\partial L}{\partial W_c} & = \frac {\partial L}{\partial b_c} \cdot hx_{t-1}^T \end{aligned} \right.
⎩⎪⎪⎨⎪⎪⎧∂bc∂L∂Wc∂L=∂ct∂L⊙it⊙(1−ct^⊙ct^)=∂bc∂L⋅hxt−1T
带着维度思考问题,
∂
L
/
∂
h
t
\partial L/\partial h_{t}
∂L/∂ht与Part1《LSTM反向传播详解Part1》中
∂
L
/
∂
h
t
\partial L/\partial h_{t}
∂L/∂ht维度是对不上的,卖个关子,后期会在part3中说明。
3.总结
本文在Part1《LSTM反向传播详解Part1》得到 ∂ L / ∂ h τ , ∂ L / ∂ c τ \partial L/\partial h_{\tau},\partial L/\partial c_{\tau} ∂L/∂hτ,∂L/∂cτ, ∂ L / ∂ h t , ∂ L / ∂ c t \partial L/\partial h_t,\partial L/\partial c_t ∂L/∂ht,∂L/∂ct基础上,进一步获得模型中矩阵参数以及偏置参数 W f , b f , W i , b i , W c , b c , W o , b o W_f,b_f,W_i,b_i,W_c,b_c,W_o,b_o Wf,bf,Wi,bi,Wc,bc,Wo,bo。个人认为虽然本文也有不少的公式(而且中间省略了一大把的中间步骤公式),但我认为更多的只是根据《链式法则》的一个小练手而已。
4.展望
所有的公式都已罗列完毕,那么以下问题:
1.Part1《LSTM反向传播详解Part1》中
∂
L
/
∂
h
t
\partial L/\partial h_{t}
∂L/∂ht维度显然为
[
(
h
_
d
i
m
e
n
s
+
x
_
d
i
m
e
n
s
)
×
1
]
[(h\_dimens+x\_dimens)\times1]
[(h_dimens+x_dimens)×1],本文中应该为
[
h
_
d
i
m
e
n
s
×
1
]
[h\_dimens\times1]
[h_dimens×1],否则
∂
L
/
∂
o
t
=
∂
L
/
∂
h
t
⊙
tanh
(
c
t
)
\partial L/\partial o_{t} = \partial L/\partial h_{t}\ \odot \tanh(c_{t})
∂L/∂ot=∂L/∂ht ⊙tanh(ct)表达式将会出错。这是怎么一回事?
2.本文公式中出现大量
i
t
,
c
t
−
1
,
f
t
,
o
t
,
c
t
^
i_t,c_{t-1},f_t,o_t,\hat{c_t}
it,ct−1,ft,ot,ct^等变量,在程序中要怎么实现?
3.本文的公式都是针对单个样本推导的,当每个batch中有大量样本时,怎么办?不同时刻的
h
t
,
c
t
h_t,c_t
ht,ct的偏导数都有对模型参数的链接,要如何实现?
4.如果理论一切正常,你相信你能得到正确的结果吗?比如参数更新采用梯度下降还是Adam?我该如何“自编”样本数据,Part1文中使用ont-hot分类标签(最后通过
s
o
f
t
m
a
x
softmax
softmax层分类)能得到正确结果吗?
上述问题都将在Part3中做进一步解释。
真理总是越辩越明,如果觉得以上博文有任何问题欢迎留言,或者电邮494431025@qq.com