This is the classic LSTM diagram. The LSTM relies on the forget gate $f_t$, the input gate $i_t$, and the output gate $o_t$ to control what information flows in and out:

$$f_{t}=\sigma\left(W_{f} \cdot\left[h_{t-1}, x_{t}\right]+b_{f}\right)$$
$$i_{t}=\sigma\left(W_{i} \cdot\left[h_{t-1}, x_{t}\right]+b_{i}\right)$$
$$o_{t}=\sigma\left(W_{o} \cdot\left[h_{t-1}, x_{t}\right]+b_{o}\right)$$
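For reference, here is a minimal NumPy sketch of a single LSTM step built from these three gates. The candidate cell state, the cell update, and the hidden-state output follow the standard LSTM formulation rather than anything specific to this post, and all dimensions and weight names are assumed purely for illustration:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_o, W_c, b_f, b_i, b_o, b_c):
    """One LSTM step; each W_* multiplies the concatenation [h_{t-1}, x_t]."""
    z = np.concatenate([h_prev, x_t])      # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ z + b_f)           # forget gate
    i_t = sigmoid(W_i @ z + b_i)           # input gate
    o_t = sigmoid(W_o @ z + b_o)           # output gate
    c_tilde = np.tanh(W_c @ z + b_c)       # candidate cell state (standard LSTM)
    c_t = f_t * c_prev + i_t * c_tilde     # new cell state
    h_t = o_t * np.tanh(c_t)               # new hidden state
    return h_t, c_t

# toy shapes just to exercise the function
hidden, inp = 4, 3
rng = np.random.default_rng(0)
Ws = [rng.standard_normal((hidden, hidden + inp)) for _ in range(4)]
bs = [np.zeros(hidden) for _ in range(4)]
h, c = np.zeros(hidden), np.zeros(hidden)
h, c = lstm_step(rng.standard_normal(inp), h, c, *Ws, *bs)
```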
We simplify these to:
$$f_{t}=\sigma\left(W_{f} X_{t}+b_{f}\right)$$
$$i_{t}=\sigma\left(W_{i} X_{t}+b_{i}\right)$$
$$o_{t}=\sigma\left(W_{o} X_{t}+b_{o}\right)$$
The current state is

$$S_{t}=f_{t} S_{t-1}+i_{t} X_{t}$$

which resembles the traditional RNN state
$$S_{t}=W_{s} S_{t-1}+W_{x} X_{t}+b_{1}$$

Expanding the LSTM state expression gives:

$$S_{t}=\sigma\left(W_{f} X_{t}+b_{f}\right) S_{t-1}+\sigma\left(W_{i} X_{t}+b_{i}\right) X_{t}$$

Adding the activation function:
$$S_{t}=\tanh \left[\sigma\left(W_{f} X_{t}+b_{f}\right) S_{t-1}+\sigma\left(W_{i} X_{t}+b_{i}\right) X_{t}\right]$$
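A minimal scalar sketch of this simplified recurrence (the weight and bias values below are toy assumptions, chosen only to show the update in motion):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Simplified recurrence from the text:
# S_t = tanh( sigma(Wf*X_t + bf) * S_{t-1} + sigma(Wi*X_t + bi) * X_t )
Wf, bf = 1.5, 0.0   # forget-gate parameters (toy values)
Wi, bi = 1.0, 0.0   # input-gate parameters (toy values)

S = 0.0
for X in np.random.default_rng(1).standard_normal(10):
    f = sigmoid(Wf * X + bf)
    i = sigmoid(Wi * X + bi)
    S = np.tanh(f * S + i * X)
    print(f"X={X:+.3f}  f={f:.3f}  i={i:.3f}  S={S:+.3f}")
```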
In the earlier post on the causes of vanishing and exploding gradients in RNNs, the backpropagation of the traditional RNN contains the product

$$\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh ^{\prime}\, W_{s}$$

The LSTM also contains such a product, but in the LSTM:
$$\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh ^{\prime}\, \sigma\left(W_{f} X_{t}+b_{f}\right)$$
Suppose $Z=\tanh ^{\prime}(x)\, \sigma(y)$; the graph of $Z$ is shown in the figure below.
It can be seen that the value of this function is essentially either 0 or 1.
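The surface can be reproduced with a short script; the plotting range $[-6, 6]$ is an assumption chosen only to make the saturating regions visible:

```python
import numpy as np
import matplotlib.pyplot as plt

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def dtanh(z):
    return 1.0 - np.tanh(z) ** 2   # tanh'(z)

# Evaluate Z = tanh'(x) * sigma(y) on a grid and plot the surface.
x, y = np.meshgrid(np.linspace(-6, 6, 200), np.linspace(-6, 6, 200))
Z = dtanh(x) * sigmoid(y)

ax = plt.figure().add_subplot(projection="3d")
ax.plot_surface(x, y, Z, cmap="viridis")
ax.set_xlabel("x"); ax.set_ylabel("y"); ax.set_zlabel("Z")
plt.show()
```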
The backpropagation (partial-derivative) computation for the traditional RNN is:
$$\frac{\partial L_{3}}{\partial W_{s}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial S_{1}} \frac{\partial S_{1}}{\partial W_{s}}$$
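As a sanity check on this three-path expansion, the following scalar PyTorch sketch compares autograd's $\partial L_3 / \partial W_s$ with the sum written above. The toy recurrence $S_t=\tanh(W_s S_{t-1}+W_x X_t+b)$ and the choice $L_3=O_3=S_3$ (so the $\partial L_3/\partial O_3$ and $\partial O_3/\partial S_3$ factors are 1) are assumptions made purely to keep the comparison readable:

```python
import torch

# toy scalar RNN: S_t = tanh(Ws * S_{t-1} + Wx * X_t + b)
Ws = torch.tensor(0.7, requires_grad=True)
Wx, b = 0.5, 0.1
X = [0.3, -0.8, 0.6]
S0 = torch.tensor(0.4)

# forward pass, keeping pre-activations z_t for the manual derivatives
S_prev, zs, Ss = S0, [], [S0]
for x in X:
    z = Ws * S_prev + Wx * x + b
    S_prev = torch.tanh(z)
    zs.append(z)
    Ss.append(S_prev)

L3 = Ss[3]            # take L3 = O3 = S3
L3.backward()         # autograd dL3/dWs

# manual three-term sum: one term per path from S3 back to Ws
with torch.no_grad():
    d = [1 - torch.tanh(z) ** 2 for z in zs]        # tanh'(z_t)
    term1 = d[2] * Ss[2]                            # dS3/dWs (direct)
    term2 = d[2] * Ws * d[1] * Ss[1]                # via S2
    term3 = d[2] * Ws * d[1] * Ws * d[0] * Ss[0]    # via S2 and S1
    manual = term1 + term2 + term3

print(float(Ws.grad), float(manual))   # the two values match
```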
In the LSTM it becomes:
$$\frac{\partial L_{3}}{\partial W_{s}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{1}}{\partial W_{s}}$$
Because

$$\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh ^{\prime}\, \sigma\left(W_{f} X_{t}+b_{f}\right) \approx 0 \text{ or } 1,$$
this is how the LSTM resolves the vanishing-gradient problem of the traditional RNN.
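A small numerical illustration of the per-step factor $\tanh^{\prime}\,\sigma(W_f X_t + b_f)$ (all values below are toy assumptions): when the forget-gate pre-activation is strongly positive and the $\tanh$ pre-activation is small, the factor sits near 1; when the gate saturates the other way, it sits near 0. The product over time steps is therefore not forced to decay the way $\prod \tanh^{\prime}\, W_s$ is for a fixed $|W_s|<1$:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def dtanh(z):
    return 1.0 - np.tanh(z) ** 2

def lstm_factor(tanh_arg, gate_arg):
    """Per-step factor tanh'(.) * sigma(.) from the product above."""
    return dtanh(tanh_arg) * sigmoid(gate_arg)

print(lstm_factor(0.1, 5.0))    # forget gate wide open -> factor ~ 0.98
print(lstm_factor(0.1, -5.0))   # forget gate closed    -> factor ~ 0.007

# Contrast with the fixed RNN factor tanh'(z) * Ws for |Ws| < 1:
Ws, steps = 0.9, 100
print(np.prod([dtanh(0.1) * Ws] * steps))         # decays toward 0 over 100 steps
print(np.prod([lstm_factor(0.1, 5.0)] * steps))   # stays far from 0
```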