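These are my notes on Liu Jianping's article "LSTM模型与前向反向传播算法" (the LSTM model and its forward and backward propagation algorithms), written with reference to two other articles as well. The formula derivations reflect my own understanding; if you spot a mistake, please point it out.
LSTM forward propagation
The forward-pass formulas are taken as given here; see the article above for details.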
The quantities involved have the following shapes:

$$
\begin{aligned}
&\boldsymbol x^{(t)}\in\mathbb{R}^{n\times 1},\quad \boldsymbol h^{(t)}\in\mathbb{R}^{m\times 1}\\
&\boldsymbol V\in\mathbb{R}^{l\times m},\quad \boldsymbol c\in\mathbb{R}^{l\times 1},\quad \hat{\boldsymbol y}^{(t)}\in\mathbb{R}^{l\times 1}\\
&\boldsymbol o^{(t)},\ \boldsymbol f^{(t)},\ \boldsymbol i^{(t)},\ \boldsymbol a^{(t)},\ \boldsymbol C^{(t)}\in\mathbb{R}^{m\times 1}\\
&\boldsymbol W_f,\ \boldsymbol W_i,\ \boldsymbol W_a,\ \boldsymbol W_o\in\mathbb{R}^{m\times m}\\
&\boldsymbol U_f,\ \boldsymbol U_i,\ \boldsymbol U_a,\ \boldsymbol U_o\in\mathbb{R}^{m\times n}\\
&\boldsymbol b_f,\ \boldsymbol b_i,\ \boldsymbol b_a,\ \boldsymbol b_o\in\mathbb{R}^{m\times 1}
\end{aligned}
$$
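To make the shapes concrete, here is a minimal sketch of one forward step in plain Python (lists only, no dependencies). It assumes the usual form that the later derivatives imply: sigmoid gates $\boldsymbol f$, $\boldsymbol i$, $\boldsymbol o$, a tanh candidate $\boldsymbol a$, and a softmax output $\hat{\boldsymbol y}=\mathrm{softmax}(\boldsymbol V\boldsymbol h+\boldsymbol c)$. The function names and the `params` dictionary layout are hypothetical, chosen only for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]

def affine(W, h, U, x, b):
    """Return W h + U x + b as a list of length m."""
    return [p + q + r for p, q, r in zip(matvec(W, h), matvec(U, x), b)]

def lstm_step(x, h_prev, C_prev, p):
    """One forward step; p maps names such as "Wf" to parameters (assumed layout)."""
    f = [sigmoid(z) for z in affine(p["Wf"], h_prev, p["Uf"], x, p["bf"])]
    i = [sigmoid(z) for z in affine(p["Wi"], h_prev, p["Ui"], x, p["bi"])]
    a = [math.tanh(z) for z in affine(p["Wa"], h_prev, p["Ua"], x, p["ba"])]
    o = [sigmoid(z) for z in affine(p["Wo"], h_prev, p["Uo"], x, p["bo"])]
    # cell state and hidden state, both of length m
    C = [cp * fj + ij * aj for cp, fj, ij, aj in zip(C_prev, f, i, a)]
    h = [oj * math.tanh(cj) for oj, cj in zip(o, C)]
    # softmax output of length l (assumed output layer)
    scores = [s + ck for s, ck in zip(matvec(p["V"], h), p["c"])]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    y_hat = [e / total for e in exps]
    return h, C, y_hat

# toy sizes (assumed): n = 3 inputs, m = 2 hidden units, l = 2 outputs
n, m, l = 3, 2, 2
params = {}
for g in "fiao":
    params["W" + g] = [[0.1] * m for _ in range(m)]  # m x m
    params["U" + g] = [[0.2] * n for _ in range(m)]  # m x n
    params["b" + g] = [0.0] * m                      # m x 1
params["V"] = [[0.3] * m for _ in range(l)]          # l x m
params["c"] = [0.0] * l                              # l x 1

h, C, y_hat = lstm_step([1.0, 0.5, -0.5], [0.0] * m, [0.0] * m, params)
print(len(h), len(C), len(y_hat))
```

The gates at one step read the hidden state of the previous step, which is why the backward pass below differentiates the quantities at $t+1$ with respect to $\boldsymbol h^{(t)}$.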
$L(t)$ consists of two parts: the loss at the current step, $l(t)$, and the loss accumulated after time $t$, $L(t+1)$:

$$
L(t) = \begin{cases} l(t)+L(t+1), & t<\tau\\ l(t), & t=\tau \end{cases}
$$
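The recursive definition of $L(t)$ can be sketched directly in Python. The function name `total_loss_from` and the 1-based time index are assumptions made for this illustration; `step_losses[t - 1]` plays the role of $l(t)$ and the list length plays the role of $\tau$.

```python
def total_loss_from(t, step_losses):
    """L(t) = l(t) + L(t+1) for t < tau, and L(tau) = l(tau); t is 1-based."""
    tau = len(step_losses)
    if t == tau:
        return step_losses[t - 1]
    return step_losses[t - 1] + total_loss_from(t + 1, step_losses)

step_losses = [0.5, 0.25, 1.0]          # l(1), l(2), l(3), so tau = 3
print(total_loss_from(3, step_losses))  # 1.0
print(total_loss_from(1, step_losses))  # 1.75
```

This is why the backward pass starts at $t=\tau$, where only $l(\tau)$ contributes, and then moves towards earlier steps.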
So when $t=\tau$, we have $\boldsymbol\delta_h^{(\tau)}=\frac{\partial L}{\partial \boldsymbol h^{(\tau)}}=\boldsymbol V^T\left(\hat{\boldsymbol y}^{(\tau)}-\boldsymbol y^{(\tau)}\right)$; for the detailed derivation, see the RNN case. In addition,
$$
\begin{aligned}
dL&=\left(\frac{\partial L}{\partial \boldsymbol h^{(\tau)}}\right)^Td\boldsymbol h^{(\tau)}\\
&=\left(\boldsymbol\delta_h^{(\tau)}\right)^Td\left(\boldsymbol o^{(\tau)}\odot \tanh(\boldsymbol C^{(\tau)})\right)\\
&=\operatorname{tr}\left(\left(\boldsymbol\delta_h^{(\tau)}\right)^T\left(\boldsymbol o^{(\tau)}\odot d\tanh(\boldsymbol C^{(\tau)})\right)\right)\\
&=\operatorname{tr}\left(\left(\boldsymbol\delta_h^{(\tau)}\odot \boldsymbol o^{(\tau)}\right)^Td\tanh(\boldsymbol C^{(\tau)})\right)\\
&=\operatorname{tr}\left(\left(\boldsymbol\delta_h^{(\tau)}\odot \boldsymbol o^{(\tau)}\right)^T\left(\left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(\tau)})\right)\odot d\boldsymbol C^{(\tau)}\right)\right)\\
&=\operatorname{tr}\left(\left[\boldsymbol\delta_h^{(\tau)}\odot \boldsymbol o^{(\tau)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(\tau)})\right)\right]^Td\boldsymbol C^{(\tau)}\right)
\end{aligned}
$$

From the third line on, only the term that depends on $\boldsymbol C^{(\tau)}$ is kept, since we are after the gradient with respect to $\boldsymbol C^{(\tau)}$.
Therefore,

$$
\boldsymbol\delta_C^{(\tau)}=\frac{\partial L}{\partial \boldsymbol C^{(\tau)}}=\boldsymbol\delta_h^{(\tau)}\odot \boldsymbol o^{(\tau)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(\tau)})\right).
$$
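Both final-step gradients can be sanity-checked against finite differences. The sketch below assumes a softmax output with cross-entropy loss, $l(\tau)=-\sum_k y_k\log\hat y_k$ with a one-hot $\boldsymbol y$, which is the setting in which $\boldsymbol V^T(\hat{\boldsymbol y}-\boldsymbol y)$ arises; all function names and toy values are hypothetical.

```python
import math

def softmax(s):
    top = max(s)
    e = [math.exp(v - top) for v in s]
    total = sum(e)
    return [v / total for v in e]

def hidden(o, C):
    """h = o * tanh(C), elementwise."""
    return [oj * math.tanh(cj) for oj, cj in zip(o, C)]

def y_hat_from_h(h, V, c):
    scores = [sum(v * hj for v, hj in zip(row, h)) + ck for row, ck in zip(V, c)]
    return softmax(scores)

def loss_from_h(h, V, c, y):
    """Cross-entropy loss (assumed form of l(tau))."""
    return -sum(yk * math.log(pk) for yk, pk in zip(y, y_hat_from_h(h, V, c)))

def delta_h_tau(h, V, c, y):
    """V^T (y_hat - y)."""
    diff = [pk - yk for pk, yk in zip(y_hat_from_h(h, V, c), y)]
    return [sum(V[k][j] * diff[k] for k in range(len(V))) for j in range(len(h))]

def delta_C_tau(delta_h, o, C):
    """delta_h * o * (1 - tanh^2(C)), elementwise."""
    return [dh * oj * (1.0 - math.tanh(cj) ** 2) for dh, oj, cj in zip(delta_h, o, C)]

def numeric_grad(fn, v, eps=1e-6):
    """Central finite differences of a scalar function of a list."""
    grad = []
    for j in range(len(v)):
        up = list(v)
        dn = list(v)
        up[j] += eps
        dn[j] -= eps
        grad.append((fn(up) - fn(dn)) / (2.0 * eps))
    return grad

# toy values (assumed): m = 2 hidden units, l = 3 classes
V = [[0.2, -0.1], [0.4, 0.3], [-0.5, 0.1]]
c = [0.0, 0.1, -0.1]
y = [0.0, 1.0, 0.0]
o = [0.6, 0.3]
C = [0.5, -1.2]

h = hidden(o, C)
d_h = delta_h_tau(h, V, c, y)
d_C = delta_C_tau(d_h, o, C)
num_h = numeric_grad(lambda v: loss_from_h(v, V, c, y), h)
num_C = numeric_grad(lambda v: loss_from_h(hidden(o, v), V, c, y), C)
err_h = max(abs(p - q) for p, q in zip(d_h, num_h))
err_C = max(abs(p - q) for p, q in zip(d_C, num_C))
print(err_h, err_C)
```

Both errors should be tiny compared with the gradient entries themselves, which is what one expects when the analytic formula and the numerical estimate agree.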
Next, we work backwards from step $t+1$:
$$
\begin{aligned}
dL&=dl(t)+dL(t+1)\\
&=\left(\frac{\partial l(t)}{\partial \boldsymbol h^{(t)}}\right)^Td\boldsymbol h^{(t)}+\left(\frac{\partial L(t+1)}{\partial \boldsymbol h^{(t+1)}}\right)^Td\boldsymbol h^{(t+1)}\\
&=\left(\frac{\partial l(t)}{\partial \boldsymbol h^{(t)}}\right)^Td\boldsymbol h^{(t)}+\left(\frac{\partial L(t+1)}{\partial \boldsymbol h^{(t+1)}}\right)^T\left(\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}\right)^Td\boldsymbol h^{(t)}
\end{aligned}
$$

Therefore:

$$
\boldsymbol \delta_h^{(t)}=\frac{\partial L}{\partial \boldsymbol h^{(t)}}=\frac{\partial l(t)}{\partial \boldsymbol h^{(t)}}+\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}\boldsymbol \delta_h^{(t+1)}
$$

Here $\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}$ is obtained as follows:

$$
\begin{aligned}
d\boldsymbol h^{(t+1)}&=d\left(\boldsymbol o^{(t+1)}\odot \tanh(\boldsymbol C^{(t+1)})\right)\\
&=\tanh(\boldsymbol C^{(t+1)})\odot d\boldsymbol o^{(t+1)}+\boldsymbol o^{(t+1)}\odot d\tanh(\boldsymbol C^{(t+1)})\\
&=\tanh(\boldsymbol C^{(t+1)})\odot \boldsymbol o^{(t+1)}\odot (\boldsymbol 1-\boldsymbol o^{(t+1)})\odot d(\boldsymbol W_o\boldsymbol h^{(t)})\\
&\quad+\boldsymbol o^{(t+1)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(t+1)})\right)\odot d\boldsymbol C^{(t+1)}\\
&=\operatorname{diag}\left[\tanh(\boldsymbol C^{(t+1)})\odot \boldsymbol o^{(t+1)}\odot (\boldsymbol 1-\boldsymbol o^{(t+1)})\right]\boldsymbol W_o\,d\boldsymbol h^{(t)}\\
&\quad+\operatorname{diag}\left[\boldsymbol o^{(t+1)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(t+1)})\right)\right]d\boldsymbol C^{(t+1)}
\end{aligned}
$$

For the cell state, $\boldsymbol C^{(t+1)}=\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}+\boldsymbol i^{(t+1)}\odot \boldsymbol a^{(t+1)}$, so

$$
\begin{aligned}
d\boldsymbol C^{(t+1)}&=\boldsymbol f^{(t+1)}\odot d\boldsymbol C^{(t)}+\boldsymbol C^{(t)}\odot d\boldsymbol f^{(t+1)}+\boldsymbol a^{(t+1)}\odot d\boldsymbol i^{(t+1)}+\boldsymbol i^{(t+1)}\odot d\boldsymbol a^{(t+1)}\\
&=\boldsymbol f^{(t+1)}\odot d\boldsymbol C^{(t)}+\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol f^{(t+1)}\right)\odot d(\boldsymbol W_f\boldsymbol h^{(t)})\\
&\quad+\boldsymbol a^{(t+1)}\odot \boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol i^{(t+1)}\right)\odot d(\boldsymbol W_i\boldsymbol h^{(t)})\\
&\quad+\boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}\right)\odot d(\boldsymbol W_a\boldsymbol h^{(t)})\\
&=\boldsymbol f^{(t+1)}\odot d\boldsymbol C^{(t)}+\operatorname{diag}\left[\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol f^{(t+1)}\right)\right]\boldsymbol W_f\,d\boldsymbol h^{(t)}\\
&\quad+\operatorname{diag}\left[\boldsymbol a^{(t+1)}\odot \boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol i^{(t+1)}\right)\right]\boldsymbol W_i\,d\boldsymbol h^{(t)}\\
&\quad+\operatorname{diag}\left[\boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}\right)\right]\boldsymbol W_a\,d\boldsymbol h^{(t)}
\end{aligned}
$$

Because $\boldsymbol a^{(t+1)}$ is already the output of a tanh, its derivative factor is $\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}$, in the same way that the sigmoid gates contribute $\boldsymbol f\odot(\boldsymbol 1-\boldsymbol f)$, $\boldsymbol i\odot(\boldsymbol 1-\boldsymbol i)$ and $\boldsymbol o\odot(\boldsymbol 1-\boldsymbol o)$.
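The derivation leans on two elementwise identities: $\sigma'(z)=\sigma(z)(1-\sigma(z))$ for the gates and $\tanh'(z)=1-\tanh^2(z)$ for the cell nonlinearity. As a small sketch, they can be checked numerically with central differences; the helper name `central_diff` and the test point `z = 0.3` are arbitrary choices for this illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def central_diff(fn, z, eps=1e-6):
    """Numerical derivative of a scalar function at z."""
    return (fn(z + eps) - fn(z - eps)) / (2.0 * eps)

z = 0.3
s = sigmoid(z)
t = math.tanh(z)

# derivative written in terms of the activation output, as in the derivation
sigmoid_err = abs(central_diff(sigmoid, z) - s * (1.0 - s))
tanh_err = abs(central_diff(math.tanh, z) - (1.0 - t * t))
print(sigmoid_err, tanh_err)
```

Note that both derivatives are expressed through the activation's output, so a quantity that is itself a tanh output contributes one minus its own square.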
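Substituting $d\boldsymbol C^{(t+1)}$ into $d\boldsymbol h^{(t+1)}$ gives a very large expression. Since we want the dependence on $\boldsymbol h^{(t)}$ alone, $\boldsymbol C^{(t)}$ is held fixed and the $\boldsymbol f^{(t+1)}\odot d\boldsymbol C^{(t)}$ term drops out. You can work through the algebra yourself; the result is: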
$$
\begin{aligned}
\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}&=\left(\boldsymbol W_o\right)^T\operatorname{diag}\left[\tanh(\boldsymbol C^{(t+1)})\odot \boldsymbol o^{(t+1)}\odot (\boldsymbol 1-\boldsymbol o^{(t+1)})\right]\\
&\quad+\left(\boldsymbol W_f\right)^T\operatorname{diag}\left[\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol f^{(t+1)}\right)\odot \Delta\boldsymbol C\right]\\
&\quad+\left(\boldsymbol W_i\right)^T\operatorname{diag}\left[\boldsymbol a^{(t+1)}\odot \boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol i^{(t+1)}\right)\odot \Delta\boldsymbol C\right]\\
&\quad+\left(\boldsymbol W_a\right)^T\operatorname{diag}\left[\boldsymbol i^{(t+1)}\odot \left(\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}\right)\odot \Delta\boldsymbol C\right]
\end{aligned}
$$

where $\Delta\boldsymbol C=\boldsymbol o^{(t+1)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(t+1)})\right)$. Since $\boldsymbol a^{(t+1)}$ is itself a tanh output, its derivative factor is $\boldsymbol 1-\boldsymbol a^{(t+1)}\odot\boldsymbol a^{(t+1)}$.
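The matrix $\frac{\partial \boldsymbol h^{(t+1)}}{\partial \boldsymbol h^{(t)}}$ can be checked numerically. The sketch below holds $\boldsymbol C^{(t)}$ and the input fixed, folds $\boldsymbol U_g\boldsymbol x^{(t+1)}+\boldsymbol b_g$ into a constant vector `u[g]` for each gate, and assumes sigmoid gates with a tanh candidate. The candidate's derivative factor is written as $1-a^2$ because $\boldsymbol a$ is already a tanh output. All names and toy values are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(M, v):
    return [sum(a * b for a, b in zip(row, v)) for row in M]

def gates(h_prev, C_prev, W, u):
    """Return (f, i, a, o, C_next, h_next); u[g] stands in for U_g x + b_g."""
    pre = {g: [s + t for s, t in zip(matvec(W[g], h_prev), u[g])] for g in "fiao"}
    f = [sigmoid(z) for z in pre["f"]]
    i = [sigmoid(z) for z in pre["i"]]
    a = [math.tanh(z) for z in pre["a"]]
    o = [sigmoid(z) for z in pre["o"]]
    C_next = [c * fj + ij * aj for c, fj, ij, aj in zip(C_prev, f, i, a)]
    h_next = [oj * math.tanh(c) for oj, c in zip(o, C_next)]
    return f, i, a, o, C_next, h_next

def dh_next_dh_prev(h_prev, C_prev, W, u):
    """Analytic matrix G with G[j][k] = d h_next[k] / d h_prev[j]."""
    f, i, a, o, C_next, _ = gates(h_prev, C_prev, W, u)
    m = len(h_prev)
    dC = [o[k] * (1.0 - math.tanh(C_next[k]) ** 2) for k in range(m)]  # Delta C
    d = {
        "o": [math.tanh(C_next[k]) * o[k] * (1.0 - o[k]) for k in range(m)],
        "f": [C_prev[k] * f[k] * (1.0 - f[k]) * dC[k] for k in range(m)],
        "i": [a[k] * i[k] * (1.0 - i[k]) * dC[k] for k in range(m)],
        "a": [i[k] * (1.0 - a[k] ** 2) * dC[k] for k in range(m)],
    }
    # (W_g^T diag[d_g])[j][k] = W_g[k][j] * d_g[k], summed over the four gates
    return [[sum(W[g][k][j] * d[g][k] for g in "fiao") for k in range(m)]
            for j in range(m)]

def numeric_jacobian(h_prev, C_prev, W, u, eps=1e-6):
    """Central finite differences in the same layout as dh_next_dh_prev."""
    m = len(h_prev)
    G = [[0.0] * m for _ in range(m)]
    for j in range(m):
        up = list(h_prev)
        dn = list(h_prev)
        up[j] += eps
        dn[j] -= eps
        h_up = gates(up, C_prev, W, u)[5]
        h_dn = gates(dn, C_prev, W, u)[5]
        for k in range(m):
            G[j][k] = (h_up[k] - h_dn[k]) / (2.0 * eps)
    return G

# toy values (assumed), m = 2
W = {
    "f": [[0.1, -0.2], [0.3, 0.4]],
    "i": [[-0.3, 0.2], [0.1, 0.5]],
    "a": [[0.2, 0.1], [-0.4, 0.3]],
    "o": [[0.5, -0.1], [0.2, 0.2]],
}
u = {"f": [0.1, -0.1], "i": [0.0, 0.2], "a": [0.3, -0.2], "o": [-0.1, 0.1]}
h_prev = [0.4, -0.3]
C_prev = [0.7, -0.5]

G = dh_next_dh_prev(h_prev, C_prev, W, u)
G_num = numeric_jacobian(h_prev, C_prev, W, u)
max_err = max(abs(G[j][k] - G_num[j][k]) for j in range(2) for k in range(2))
print(max_err)
```

The index order `G[j][k]` matches the layout used in the recursion for $\boldsymbol\delta_h^{(t)}$, where this matrix multiplies $\boldsymbol\delta_h^{(t+1)}$ from the left.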
With $\boldsymbol \delta_h^{(t)}$ in hand, $\boldsymbol \delta_C^{(t)}$ is easy to obtain, and once both $\boldsymbol \delta_h^{(t)}$ and $\boldsymbol \delta_C^{(t)}$ are available, the gradients of the remaining parameters follow easily as well.
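As a sketch of what these last steps can look like (written out here as an illustration rather than quoted from the referenced articles): differentiating $\boldsymbol C^{(t+1)}=\boldsymbol C^{(t)}\odot \boldsymbol f^{(t+1)}+\boldsymbol i^{(t+1)}\odot \boldsymbol a^{(t+1)}$ and $\boldsymbol h^{(t)}=\boldsymbol o^{(t)}\odot\tanh(\boldsymbol C^{(t)})$ with respect to $\boldsymbol C^{(t)}$ suggests the recursion below. The $\boldsymbol W_f$ gradient is given as one example and assumes $\boldsymbol f^{(t)}=\sigma(\boldsymbol W_f\boldsymbol h^{(t-1)}+\boldsymbol U_f\boldsymbol x^{(t)}+\boldsymbol b_f)$.

$$
\boldsymbol\delta_C^{(t)}=\boldsymbol\delta_C^{(t+1)}\odot \boldsymbol f^{(t+1)}+\boldsymbol\delta_h^{(t)}\odot \boldsymbol o^{(t)}\odot \left(\boldsymbol 1-\tanh^2(\boldsymbol C^{(t)})\right)
$$

$$
\frac{\partial L}{\partial \boldsymbol W_f}=\sum_{t=1}^{\tau}\left[\boldsymbol\delta_C^{(t)}\odot \boldsymbol C^{(t-1)}\odot \boldsymbol f^{(t)}\odot \left(\boldsymbol 1-\boldsymbol f^{(t)}\right)\right]\left(\boldsymbol h^{(t-1)}\right)^T
$$

The gradients for the other gates have the same shape: the elementwise factor in brackets changes with the gate, and the outer product with $\boldsymbol h^{(t-1)}$ (or with $\boldsymbol x^{(t)}$ for the $\boldsymbol U$ matrices) stays the same.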