本文参考了李宏毅机器学习视频、RNN梯度消失和爆炸的原因、LSTM如何解决梯度消失问题(很清晰但要梯子)、神经网络中存在的问题、LSTM原理详解(这篇也不错但是有点点拼拼凑凑)
【NLP基础理论】RNN 中简单介绍了RNN的基础理论,如果对RNN还没有认知的可以先看这篇。
目录
LSTM现在都已经成为一个标准RNN,大家说RNN多半指的是LSTM,而最开始的RNN多称为 Simple RNN。所以本文主要是对于SimpleRNN为什么会存在 梯度消失/爆炸问题进行说明。
RNN 随着epoch增加,通常情况下是 loss 越来越小(下图蓝色),但是有的时候 loss 抖动的特别厉害,然后爆掉(下图绿色),这就是发生了梯度消失/爆炸。
直观感受梯度消失和爆炸(特例)
下方是一个极其简单的RNN,1000 个输入(时间步),除了第一个输入为 1 以外,其余均为 0。中间RNN函数为线性函数 y = x,没有bias,输入和输出的参数都为 1,输入给下一个 RNN 的参数为 w。那么根据公式
s
i
=
w
s
i
−
1
+
1
⋅
x
i
y
i
=
σ
(
1
⋅
s
i
)
s_i=ws_{i-1}+1\cdot x_i \\ y_i=\sigma(1\cdot s_i)
si=wsi−1+1⋅xiyi=σ(1⋅si)
可以得到最后
y
1000
y_{1000}
y1000 的输出为
w
999
w^{999}
w999。
当我们尝试去改变参数
w
w
w 的时候,我们会发现,只要参数有一点变化,对于长句子(长输入)来说,最后一向的输出会有很大的变化!
比如
w
w
w 从 1 变成 1.01,那么 最后的输出会从 1 变成约 20000,当
w
w
w 从 1 变成小于 1 的数字时,最后输出直接约等于 0。
梯度,我们也可以理解为 直观地看到参数
w
w
w 的变化会引起最后输出多大的变化。
因此上图中绿色部分
∂
L
∂
w
\frac{\partial L}{\partial w}
∂w∂L 值较大,而黄色部分值较小。但我们可以通过对于梯度较大的地方用较小的learning rate来对冲,反之亦然。可是这很难解决,因为RNN这个性质会导致梯度时大时小,error surface就会非常崎岖。
数学感受梯度消失和梯度爆炸
简单回忆 反向传播(BP) 的流程:
这是一个简单的神经网络,其中
θ
21
(
2
)
\theta_{21}^{(2)}
θ21(2) 代表第 2 层参数链接了前面第 2 个神经元和后面第 1 个神经元的参数;
z
1
1
,
a
1
1
z^1_{\ 1},a^1_{\ 1}
z 11,a 11分别表示第一层第一个激活函数的输入值和输出值。此处所有激活函数为sigmoid函数。
- 误差函数 E = 1 2 ( y − y ^ ) 2 E=\frac{1}{2}(y-\hat{y})^2 E=21(y−y^)2
- 参数迭代公式
θ j k ( l ) ← θ j k ( l ) + Δ θ j k ( l ) w h e r e , θ j k ( l ) = − η ∂ E ∂ θ j k ( l ) = η δ k ( l ) a j ( l − 1 ) \theta_{jk}^{(l)} \leftarrow \theta_{jk}^{(l)} + \Delta \theta_{jk}^{(l)}\\ where,\theta_{jk}^{(l)}=-\eta\frac{\partial E}{\partial \theta_{jk}^{(l)}} = \eta \delta^{(l)}_ka^{(l-1)}_j θjk(l)←θjk(l)+Δθjk(l)where,θjk(l)=−η∂θjk(l)∂E=ηδk(l)aj(l−1) - 以
θ
11
(
2
)
\theta^{(2)}_{11}
θ11(2) 为例,最后一层参数更新公式为:
E = 1 2 ( y − a 1 ( 2 ) ) 2 δ k ( l ) = ( 1 − σ ( z k ( l ) ) ) σ ( z k ( l ) ) ( y − a k ( l ) ) E=\frac{1}{2}(y-a_1^{(2)})^2\\ \delta_k^{(l)}=(1-\sigma(z^{(l)}_k))\sigma(z^{(l)}_k)(y-a_k^{(l)}) E=21(y−a1(2))2δk(l)=(1−σ(zk(l)))σ(zk(l))(y−ak(l))
推导: ∂ E ∂ θ 11 ( 2 ) = ∂ E ∂ a 1 2 ⋅ ∂ a 1 2 ∂ z 1 2 ⋅ ∂ z 1 2 ∂ θ 11 ( 2 ) \frac{\partial E}{\partial \theta_{11}^{(2)}}=\frac{\partial E}{\partial a^2_{\ 1}}\cdot \frac{\partial a^2_{\ 1}}{\partial z^2_{\ 1}}\cdot \frac{\partial z^2_{\ 1}}{\partial \theta_{11}^{(2)}} ∂θ11(2)∂E=∂a 12∂E⋅∂z 12∂a 12⋅∂θ11(2)∂z 12
其中 E = 1 2 ( y − a 1 2 ) 2 → ∂ E ∂ a 1 2 = − y + a 1 2 a 1 2 = 1 1 + e − z 1 2 → ∂ a 1 2 ∂ z 1 2 = e − z 1 2 ( 1 + e − z 1 2 ) 2 = ( 1 − σ ( z 1 2 ) ) σ ( z 1 2 ) z 1 2 = a 0 1 ⋅ θ 01 ( 2 ) + a 1 1 ⋅ θ 11 ( 2 ) + a 2 1 ⋅ θ 21 ( 2 ) → ∂ z 1 2 ∂ θ 11 ( 2 ) = a 1 1 E=\frac{1}{2}(y-a^2_{\ 1})^2 \rightarrow \frac{\partial E}{\partial a^2_{\ 1}}=-y+a^2_{\ 1}\\ a^2_{\ 1}=\frac{1}{1+e^{-z^2_{\ 1}}} \rightarrow \frac{\partial a^2_{\ 1}}{\partial z^2_{\ 1}}=\frac{e^{-z^2_{\ 1}}}{(1+e^{-z^2_{\ 1}})^2}=(1-\sigma(z^2_{\ 1}))\sigma(z^2_{\ 1})\\ z^2_{\ 1}=a^1_{\ 0} \cdot \theta_{01}^{(2)}+a^1_{\ 1} \cdot \theta_{11}^{(2)}+a^1_{\ 2}\cdot \theta_{21}^{(2)} \rightarrow \frac{\partial z^2_{\ 1}}{\partial \theta_{11}^{(2)}}=a^1_{\ 1} E=21(y−a 12)2→∂a 12∂E=−y+a 12a 12=1+e−z 121→∂z 12∂a 12=(1+e−z 12)2e−z 12=(1−σ(z 12))σ(z 12)z 12=a 01⋅θ01(2)+a 11⋅θ11(2)+a 21⋅θ21(2)→∂θ11(2)∂z 12=a 11 - 以
θ
11
(
1
)
\theta ^{(1)}_{11}
θ11(1) 其它层的参数更新公式为:
δ k ( l ) = ( 1 − σ ( z k ( l ) ) ) σ ( z k ( l ) ) θ k 1 ( l + 1 ) δ 1 ( l + 1 ) \delta_k^{(l)}=(1-\sigma(z^{(l)}_k))\sigma(z^{(l)}_k)\theta_{k1}^{(l+1)}\delta_1^{(l+1)} δk(l)=(1−σ(zk(l)))σ(zk(l))θk1(l+1)δ1(l+1)
推导: ∂ E ∂ θ 11 ( 2 ) = ∂ E ∂ a 1 2 ⋅ ∂ a 1 2 ∂ z 1 2 ⋅ ∂ z 1 2 ∂ a 1 2 ⋅ ∂ a 1 1 ∂ z 1 1 ⋅ ∂ z 1 1 ∂ θ 11 ( 1 ) \frac{\partial E}{\partial \theta_{11}^{(2)}}=\frac{\partial E}{\partial a^2_{\ 1}}\cdot \frac{\partial a^2_{\ 1}}{\partial z^2_{\ 1}}\cdot \frac{\partial z^2_{\ 1}}{\partial a^2_{\ 1}}\cdot \frac{\partial a^1_{\ 1}}{\partial z^1_{\ 1}}\cdot \frac{\partial z^1_{\ 1}}{\partial \theta ^{(1)}_{11}} ∂θ11(2)∂E=∂a 12∂E⋅∂z 12∂a 12⋅∂a 12∂z 12⋅∂z 11∂a 11⋅∂θ11(1)∂z 11
这个推导没有出现一个神经元会有两个 E E E 共同带来的误差。如果有多个的话,这要把每个 E E E 的值相加即可。
题外话:基本上所有的均方误差损失函数(MSE) 都是以下形式表现:
L
(
Y
∣
f
(
x
)
)
=
1
N
∑
i
=
1
N
(
Y
i
−
f
(
x
i
)
)
2
L(Y|f(x))=\frac{1}{N}\sum_{i=1}^{N}(Y_i-f(x_i))^2
L(Y∣f(x))=N1i=1∑N(Yi−f(xi))2
但后来突然发现还有将
N
N
N变为
2
N
2N
2N作为分母使用的(参考微软github)
题外话结束
简单回忆 SimpleRNN 模型:
S
i
=
t
a
n
h
(
W
s
S
i
−
1
+
W
x
x
i
+
b
1
)
O
i
=
σ
(
W
o
S
i
+
b
2
)
S_i = tanh(W_sS_{i-1}+W_xx_i+b_1)\\ O_i = \sigma (W_oS_i+b_2)
Si=tanh(WsSi−1+Wxxi+b1)Oi=σ(WoSi+b2)
好的,回忆完了,RNN会有多个输出,这里如果要使用BP的话一定是要考虑到时间的,因此RNN的BP叫做时间反向传播-BPTT(Back Propagation Through Time)
开始BPTT
(参考链接)
再从最简单的例子开始,下图是只有三个输入的RNN,没有任何激活函数:
s
1
=
w
x
x
1
+
w
s
s
0
+
b
1
,
y
1
=
w
y
s
1
+
b
2
s
2
=
w
x
x
2
+
w
s
s
1
+
b
1
,
y
2
=
w
y
s
2
+
b
2
s
3
=
w
x
x
2
+
w
s
s
2
+
b
1
,
y
2
=
w
y
s
3
+
b
2
s_1 = w_xx_1+w_ss_0+b_1,y_1=w_ys_1+b_2\\ s_2 = w_xx_2+w_ss_1+b_1,y_2=w_ys_2+b_2\\ s_3 = w_xx_2+w_ss_2+b_1,y_2=w_ys_3+b_2
s1=wxx1+wss0+b1,y1=wys1+b2s2=wxx2+wss1+b1,y2=wys2+b2s3=wxx2+wss2+b1,y2=wys3+b2
那在
t
=
3
t=3
t=3 的时刻,误差函数为
E
3
=
1
2
(
Y
3
−
y
3
)
2
E_3=\frac{1}{2}(Y_3-y_3)^2
E3=21(Y3−y3)2。
一次训练下所有的误差值为单个误差之和:
E
=
∑
t
=
0
T
E
t
E = \sum_{t=0}^{T}E_t
E=∑t=0TEt
而我们的目标是去更新所有的参数
w
x
,
w
s
,
w
y
w_x,w_s,w_y
wx,ws,wy,所以需要计算误差项的梯度。
RNN中误差项的梯度并更新参数:
∂
E
∂
W
=
∑
t
=
1
T
∂
E
t
∂
W
W
←
W
−
η
∂
E
∂
W
\frac{\partial E}{\partial W} = \sum_{t=1}^T\frac{\partial E_t}{\partial W}\\ W \leftarrow W - \eta \frac{\partial E}{\partial W}
∂W∂E=t=1∑T∂W∂EtW←W−η∂W∂E
此处,我们对
t
=
1
t=1
t=1 时刻入手开始更新:
∂
E
1
∂
w
y
=
∂
E
1
∂
y
1
∂
y
1
∂
w
y
∂
E
1
∂
w
x
=
∂
E
1
∂
y
1
∂
y
1
∂
s
1
∂
s
1
∂
w
x
∂
E
1
∂
w
s
=
∂
E
1
∂
y
1
∂
y
1
∂
s
1
∂
s
1
∂
w
s
\frac{\partial E_1}{\partial w_y}=\frac{\partial E_1}{\partial y_1}\frac{\partial y_1}{\partial w_y}\\ \frac{\partial E_1}{\partial w_x}=\frac{\partial E_1}{\partial y_1}\frac{\partial y_1}{\partial s_1}\frac{\partial s_1}{\partial w_x}\\ \frac{\partial E_1}{\partial w_s}=\frac{\partial E_1}{\partial y_1}\frac{\partial y_1}{\partial s_1}\frac{\partial s_1}{\partial w_s}\\
∂wy∂E1=∂y1∂E1∂wy∂y1∂wx∂E1=∂y1∂E1∂s1∂y1∂wx∂s1∂ws∂E1=∂y1∂E1∂s1∂y1∂ws∂s1
如果从 t=3 时刻开始,那么就需要每一次向后传递时,分一部分给
w
x
w_x
wx再分一部分错误给后面。
∂
E
∂
w
x
=
∂
E
3
∂
y
3
∂
y
3
∂
s
3
∂
s
3
∂
w
x
+
∂
E
3
∂
y
3
∂
y
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
w
x
+
∂
E
3
∂
y
3
∂
y
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
s
1
∂
s
1
∂
w
x
∂
E
∂
w
s
=
∂
E
3
∂
y
3
∂
y
3
∂
s
1
∂
s
3
∂
w
s
+
∂
E
3
∂
y
3
∂
y
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
w
s
+
∂
E
3
∂
y
3
∂
y
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
s
1
∂
s
1
∂
w
s
(
因
为
s
3
是
包
含
了
s
2
和
s
1
的
)
\frac{\partial E}{\partial w_x}=\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial w_x}+\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial w_x}+\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial w_x}\\ \frac{\partial E}{\partial w_s}=\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_1}\frac{\partial s_3}{\partial w_s}+\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial w_s}+\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial w_s}\\(因为s_3是包含了s_2和s_1的)
∂wx∂E=∂y3∂E3∂s3∂y3∂wx∂s3+∂y3∂E3∂s3∂y3∂s2∂s3∂wx∂s2+∂y3∂E3∂s3∂y3∂s2∂s3∂s1∂s2∂wx∂s1∂ws∂E=∂y3∂E3∂s1∂y3∂ws∂s3+∂y3∂E3∂s3∂y3∂s2∂s3∂ws∂s2+∂y3∂E3∂s3∂y3∂s2∂s3∂s1∂s2∂ws∂s1(因为s3是包含了s2和s1的)
对上述偏导公式进行总结,得出所有时刻的梯度之和:
∂
E
∂
w
x
=
∑
k
=
0
t
∂
E
t
∂
y
t
∂
y
t
∂
s
t
(
∏
j
=
k
+
1
t
∂
s
j
∂
s
j
−
1
)
∂
s
k
∂
w
x
\frac{\partial E}{\partial w_x}=\sum_{k=0}^t\frac{\partial E_t}{\partial y_t}\frac{\partial y_t}{\partial s_t}(\prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}})\frac{\partial s_k}{\partial w_x}
∂wx∂E=k=0∑t∂yt∂Et∂st∂yt(j=k+1∏t∂sj−1∂sj)∂wx∂sk
w
s
w_s
ws 同上。
因为上述式子是假设没有任何激活函数,下式是任意时刻的梯度传递到时间步1时候的公式:
∂
E
k
∂
w
x
=
∂
E
k
∂
y
k
∂
y
k
∂
s
k
(
∏
t
=
2
k
∂
s
t
∂
s
t
−
1
)
∂
s
1
∂
w
x
\frac{\partial E_k}{\partial w_x}=\frac{\partial E_k}{\partial y_k}\frac{\partial y_k}{\partial s_k}(\prod_{t=2}^k\frac{\partial s_t}{\partial s_{t-1}})\frac{\partial s_1}{\partial w_x}
∂wx∂Ek=∂yk∂Ek∂sk∂yk(t=2∏k∂st−1∂st)∂wx∂s1
因此,
在没有任何激活函数的情况下
∏
j
=
k
+
1
t
∂
s
j
∂
s
j
−
1
\prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}}
∏j=k+1t∂sj−1∂sj 是
t
−
k
−
1
t-k-1
t−k−1个
w
s
w_s
ws相乘。那么
w
s
w_s
ws的大小就会影响梯度爆炸还是消失。
若有激活函数 :
s
j
=
t
a
n
h
(
w
s
s
j
−
1
+
w
x
x
j
)
s_j=tanh(w_ss_{j-1}+w_xx_j)
sj=tanh(wssj−1+wxxj)
求偏导则先求
t
a
n
h
(
x
)
tanh(x)
tanh(x)的导,再求
x
x
x 的导:
∂
s
j
∂
s
j
−
1
=
t
a
n
h
′
(
w
s
s
j
−
1
+
w
x
x
j
)
⋅
∂
∂
s
j
−
1
[
w
s
s
j
−
1
+
w
x
x
j
]
=
t
a
n
h
′
(
w
s
s
j
−
1
+
w
x
x
j
)
⋅
w
s
\frac{\partial s_j}{\partial s_{j-1}}=tanh'(w_ss_{j-1}+w_xx_j)\cdot \frac{\partial}{\partial s_{j-1}}[w_ss_{j-1}+w_xx_j]\\ =tanh'(w_ss_{j-1}+w_xx_j)\cdot w_s
∂sj−1∂sj=tanh′(wssj−1+wxxj)⋅∂sj−1∂[wssj−1+wxxj]=tanh′(wssj−1+wxxj)⋅ws
t
a
n
h
′
tanh'
tanh′ 取值在[0,1],后面还是有
w
s
w_s
ws !可恶。
(补充: 激活函数
t
a
n
h
(
x
)
=
e
x
−
e
−
x
e
x
+
e
−
x
tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}
tanh(x)=ex+e−xex−e−x)
重点!!!
所以到这里就可以发现,只要我时间步长够长,就会有越来越多的
w
s
w_s
ws 相乘,如果说
w
s
w_s
ws 正好也在 [0,1] ,配合上
t
a
n
h
′
tanh'
tanh′,我们的梯度就消失了(所以我们可以认为是
t
a
n
h
′
tanh'
tanh′和
w
s
w_s
ws两者一起导致梯度消失)。
但当然,因为变量终究还是
w
s
w_s
ws (可能会大于1,但
t
a
n
h
′
tanh'
tanh′肯定小于等于1),如果
w
s
w_s
ws 大的能抵消掉 那么多个
t
a
n
h
′
tanh'
tanh′相乘 的时候,就会造成梯度爆炸。
如果梯度消失,那么
∂
E
∂
W
=
∑
t
=
1
T
∂
E
t
∂
W
→
0
\frac{\partial E}{\partial W} = \sum_{t=1}^T\frac{\partial E_t}{\partial W} \rightarrow 0
∂W∂E=t=1∑T∂W∂Et→0
这也就导致我们的参数在合理的时间内就没怎么更新了。
W
←
W
−
η
∂
E
∂
W
≈
W
W \leftarrow W - \eta \frac{\partial E}{\partial W} \approx W
W←W−η∂W∂E≈W
干掉它:)
(参考链接)
现在知道了梯度消失和爆炸的问题就在于
∂
E
t
∂
w
x
=
∑
k
=
0
t
∂
E
t
∂
y
t
∂
y
t
∂
s
t
(
∏
j
=
k
+
1
t
∂
s
j
∂
s
j
−
1
)
∂
s
k
∂
w
x
\frac{\partial E_t}{\partial w_x}=\sum_{k=0}^t\frac{\partial E_t}{\partial y_t}\frac{\partial y_t}{\partial s_t}(\prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}})\frac{\partial s_k}{\partial w_x}
∂wx∂Et=k=0∑t∂yt∂Et∂st∂yt(j=k+1∏t∂sj−1∂sj)∂wx∂sk
中的
∏
j
=
k
+
1
t
∂
s
j
∂
s
j
−
1
\prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}}
∏j=k+1t∂sj−1∂sj,最直观的想法就是让它乘来乘去一直约为 1 或者 一直约为 0 ,这样就不会对整体的梯度有很大的影响。
LSTM 可以解决。(Clockwise RNN 和 SCRN 也可以,但这里不讲了)
简单回忆LSTM
(一个非常详细的LSTM介绍)
下面是三个时间步长的LSTM,
t
t
t时的输入是
[
h
t
−
1
,
x
t
]
[h_{t-1},x_{t}]
[ht−1,xt] 当前输入与上一个的输出相结合。三个橙色的
σ
\sigma
σ 函数是LSTM的三个gate,从左至右分别为遗忘门、输入门以及输出门。
遗忘门 根据新的输入,经过激活函数之后,得到一个 0 - 1 的值,这个值决定了过去的记忆 c t − 1 c_{t-1} ct−1 有多少被保留。:
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
)
f_t = \sigma (W_f\cdot [h_{t-1},x_t])
ft=σ(Wf⋅[ht−1,xt])
输入门 控制了新输入有多少要被加入到记忆中,但这里的输出还需要配合上神经网络觉得新输入中有用的部分。
t
a
n
h
(
W
c
⋅
[
h
t
−
1
,
x
t
]
)
∗
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
)
tanh(W_c\cdot [h_{t-1},x_t])*\sigma(W_i \cdot [h_{t-1},x_t])
tanh(Wc⋅[ht−1,xt])∗σ(Wi⋅[ht−1,xt])
这里会拆成以下两个矩阵进行对应相乘:
c
t
~
=
t
a
n
h
(
W
c
⋅
[
h
t
−
1
,
x
t
]
)
i
t
=
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
)
\tilde{c_t} = tanh(W_c \cdot [h_{t-1},x_t])\\ i_t=\sigma(W_i \cdot[h_{t-1},x_t])
ct~=tanh(Wc⋅[ht−1,xt])it=σ(Wi⋅[ht−1,xt])
输出门 控制了有多少信息会被编辑到记忆细胞,作为下一个时间步的输入。
h
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
)
∗
t
a
n
h
(
c
t
)
h_t = \sigma(W_o \cdot [h_{t-1},x_t])*tanh(c_t)
ht=σ(Wo⋅[ht−1,xt])∗tanh(ct)
记忆细胞 的内容就根据上述结果进行更新:
c
t
=
c
t
−
1
∗
f
t
+
c
t
~
∗
i
t
c_t=c_{t-1} * f_t + \tilde{c_t}*i_t
ct=ct−1∗ft+ct~∗it
题外话:看到这里你会发现为什么LSTM会用到两个
t
a
n
h
tanh
tanh 函数?
找了半天就这个链接中点赞第二多的比较有道理(点赞最多的那个我实在不懂tanh怎么就有他说的多阶导能很长时间不为0的表现)。
大概的意思就是
t
a
n
h
tanh
tanh的值在 [-1,1] 之间,但 sigmoid 在 [0,1] 之间,所以使用
t
a
n
h
tanh
tanh 可以生成负值。我了解到在神经网络中,像
t
a
n
h
tanh
tanh 这样的0中心激活函数可以加快收敛速度(关于0中心这篇不错),那可能也是这个原因我们在这里使用
t
a
n
h
tanh
tanh,应该是其他激活函数也可用(但这里不是百分百保证,最好对于激活函数的区别再了解一下)。
题外话结束
LSTM中的BPTT
总结一下 LSTM 里发生的公式们(
⋅
\cdot
⋅ 表示矩阵相乘,
∗
*
∗表示矩阵元素对应相乘):
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
)
c
t
~
=
t
a
n
h
(
W
c
⋅
[
h
t
−
1
,
x
t
]
)
i
t
=
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
)
o
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
)
h
t
=
o
t
∗
t
a
n
h
(
c
t
)
,
输
出
c
t
=
c
t
−
1
∗
f
t
+
c
t
~
∗
i
t
,
记
忆
更
新
f_t = \sigma (W_f\cdot [h_{t-1},x_t])\\\tilde{c_t} = tanh(W_c \cdot [h_{t-1},x_t])\\ i_t=\sigma(W_i \cdot[h_{t-1},x_t])\\ o_t = \sigma(W_o \cdot [h_{t-1},x_t])\\ h_t = o_t*tanh(c_t) ,输出 \\c_t=c_{t-1} * f_t + \tilde{c_t}*i_t,记忆更新
ft=σ(Wf⋅[ht−1,xt])ct~=tanh(Wc⋅[ht−1,xt])it=σ(Wi⋅[ht−1,xt])ot=σ(Wo⋅[ht−1,xt])ht=ot∗tanh(ct),输出ct=ct−1∗ft+ct~∗it,记忆更新
现在我们放一个时间步长为3的 LSTM :
列出涉及到的式子(嵌套的就不打开了,太多太乱了):
c
1
=
c
0
∗
f
1
+
t
a
n
h
(
W
c
⋅
[
h
0
,
x
1
]
)
∗
i
1
h
1
=
o
1
∗
t
a
n
h
(
c
1
)
c
2
=
c
1
∗
f
2
+
t
a
n
h
(
W
c
⋅
[
h
1
,
x
2
]
)
∗
i
2
h
2
=
o
2
∗
t
a
n
h
(
c
2
)
c
3
=
c
2
∗
f
3
+
t
a
n
h
(
W
c
⋅
[
h
2
,
x
3
]
)
∗
i
3
h
3
=
o
3
∗
t
a
n
h
(
c
3
)
c_1 = c_0 * f_1+tanh(W_c\cdot [h_0,x_1])*i_1\\ h_1 = o_1*tanh(c_1)\\ c_2 = c_1 * f_2+tanh(W_c\cdot [h_1,x_2])*i_2\\ h_2 = o_2*tanh(c_2)\\ c_3 = c_2 * f_3+tanh(W_c\cdot [h_2,x_3])*i_3\\ h_3 = o_3*tanh(c_3)\\
c1=c0∗f1+tanh(Wc⋅[h0,x1])∗i1h1=o1∗tanh(c1)c2=c1∗f2+tanh(Wc⋅[h1,x2])∗i2h2=o2∗tanh(c2)c3=c2∗f3+tanh(Wc⋅[h2,x3])∗i3h3=o3∗tanh(c3)
现在要去更新参数,共计四个
W
f
,
W
c
,
W
i
,
W
o
W_f,W_c,W_i,W_o
Wf,Wc,Wi,Wo:
∂
E
3
∂
W
f
=
∂
E
3
∂
h
3
∂
h
3
∂
c
3
∂
c
3
∂
c
2
∂
c
2
∂
c
1
∂
c
1
∂
W
f
\frac{\partial E_3}{\partial W_f} = \frac{\partial E_3}{\partial h_3}\frac{\partial h_3}{\partial c_3}\frac{\partial c_3}{\partial c_2}\frac{\partial c_2}{\partial c_1}\frac{\partial c_1}{\partial W_f}
∂Wf∂E3=∂h3∂E3∂c3∂h3∂c2∂c3∂c1∂c2∂Wf∂c1
求偏导过程中,主要看函数中有
W
f
W_f
Wf的是哪部分,很快我们发现
c
i
c_i
ci函数中直观的包含了
W
c
,
W
i
W_c,W_i
Wc,Wi和
W
f
W_f
Wf,因此
W
c
,
W
i
W_c,W_i
Wc,Wi和
W
f
W_f
Wf的偏导公式都相同。
再看参数
W
o
W_o
Wo,它与
h
3
h_3
h3是直接关系。
∂
E
3
∂
W
o
=
∂
E
3
∂
h
3
∂
h
3
∂
W
o
\frac{\partial E_3}{\partial W_o} = \frac{\partial E_3}{\partial h_3}\frac{\partial h_3}{\partial W_o}
∂Wo∂E3=∂h3∂E3∂Wo∂h3
因此我们也可以总结出和RNN类似的公式,即任意时刻误差项公式:
∂
E
k
∂
W
c
=
∂
E
k
∂
h
k
∂
h
k
∂
c
k
(
∏
t
=
2
k
∂
c
t
∂
c
t
−
1
)
∂
c
1
∂
W
c
\frac{\partial E_k}{\partial W_c}=\frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}(\prod_{t=2}^k\frac{\partial c_t}{\partial c_{t-1}})\frac{\partial c_1}{\partial W_c}
∂Wc∂Ek=∂hk∂Ek∂ck∂hk(t=2∏k∂ct−1∂ct)∂Wc∂c1
W
c
,
W
f
,
W
i
W_c,W_f,W_i
Wc,Wf,Wi 同上。
接着看到
∂
c
t
∂
c
t
−
1
\frac{\partial c_t}{\partial c_{t-1}}
∂ct−1∂ct,举个t=2的例子:
c
2
=
c
1
∗
f
2
+
c
2
~
∗
i
2
∂
c
2
∂
c
1
=
∂
(
c
1
∗
f
2
)
∂
c
1
+
∂
(
c
2
~
∗
i
2
)
∂
c
1
=
c
1
⋅
f
2
′
+
c
1
′
⋅
f
2
+
c
2
~
⋅
i
2
′
+
c
2
~
′
⋅
i
2
=
c
1
⋅
∂
f
2
∂
c
1
+
f
2
+
c
2
~
⋅
∂
i
2
∂
c
1
+
i
2
⋅
∂
c
2
~
∂
c
1
=
c
1
⋅
∂
f
2
∂
h
1
∂
h
1
∂
c
1
+
f
2
+
c
2
~
⋅
∂
i
2
∂
h
1
∂
h
1
∂
c
1
+
i
2
⋅
∂
c
2
~
∂
h
1
∂
h
1
∂
c
1
=
c
1
⋅
σ
′
(
W
f
⋅
[
h
1
,
x
2
]
)
⋅
W
f
⋅
o
1
⋅
t
a
n
h
′
(
c
1
)
+
f
2
+
c
2
~
⋅
σ
′
(
W
i
⋅
[
h
1
,
x
2
]
)
⋅
W
i
⋅
o
1
⋅
t
a
n
h
′
(
c
1
)
+
i
2
⋅
σ
′
(
W
c
⋅
[
h
1
,
x
2
]
)
⋅
W
c
⋅
o
1
⋅
t
a
n
h
′
(
c
1
)
\begin{aligned} c_2 &= c_1 * f_2+\tilde{c_2}*i_2\\ \frac{\partial c_2}{\partial c_1} &= \frac{\partial (c_1*f_2)}{\partial c_1}+\frac{\partial (\tilde{c_2}*i_2)}{\partial c_1}\\ &=c_1 \cdot f_2'+c_1'\cdot f_2+\tilde{c_2}\cdot i_2'+\tilde{c_2}'\cdot i_2\\ &=c_1\cdot \frac{\partial f_2}{\partial c_1}+f_2+\tilde{c_2}\cdot \frac{\partial i_2}{\partial c_1}+i_2\cdot \frac{\partial \tilde{c_2}}{\partial c_1} \\ & = c_1\cdot \frac{\partial f_2}{\partial h_1}\frac{\partial h_1}{\partial c_1}+f_2+\tilde{c_2}\cdot \frac{\partial i_2}{\partial h_1}\frac{\partial h_1}{\partial c_1}+i_2\cdot \frac{\partial \tilde{c_2}}{\partial h_1}\frac{\partial h_1}{\partial c_1}\\ &=c_1\cdot \sigma'(W_f\cdot [h_1,x_2]) \cdot W_f \cdot o_1 \cdot tanh'(c_1)\\ &\ \ \ +f_2\\ &\ \ \ +\tilde{c_2}\cdot \sigma'(W_i\cdot [h_1,x_2]) \cdot W_i \cdot o_1 \cdot tanh'(c_1)\\ &\ \ \ + i_2 \cdot \sigma'(W_c\cdot [h_1,x_2]) \cdot W_c \cdot o_1 \cdot tanh'(c_1) \end{aligned}
c2∂c1∂c2=c1∗f2+c2~∗i2=∂c1∂(c1∗f2)+∂c1∂(c2~∗i2)=c1⋅f2′+c1′⋅f2+c2~⋅i2′+c2~′⋅i2=c1⋅∂c1∂f2+f2+c2~⋅∂c1∂i2+i2⋅∂c1∂c2~=c1⋅∂h1∂f2∂c1∂h1+f2+c2~⋅∂h1∂i2∂c1∂h1+i2⋅∂h1∂c2~∂c1∂h1=c1⋅σ′(Wf⋅[h1,x2])⋅Wf⋅o1⋅tanh′(c1) +f2 +c2~⋅σ′(Wi⋅[h1,x2])⋅Wi⋅o1⋅tanh′(c1) +i2⋅σ′(Wc⋅[h1,x2])⋅Wc⋅o1⋅tanh′(c1)
注意!我们之所以不直接把
c
1
c_1
c1提出来,是因为
f
2
f_2
f2 中是包含了
c
1
c_1
c1 的!!
f
2
f_2
f2包含了
h
1
h_1
h1,而
h
1
h_1
h1包含了
c
1
c_1
c1。
作为简单的记忆,我们就把
∂
c
2
∂
c
1
\frac{\partial c_2}{\partial c_1}
∂c1∂c2拆成了四项,除了
f
t
f_t
ft那一项,其他都是一个套路。
配上一个可视的BPTT
总结一下偏导公式:
∂
c
t
∂
c
t
−
1
=
∂
∂
c
t
−
1
[
c
t
−
1
∗
f
t
+
c
t
~
∗
i
t
]
=
∂
∂
c
t
−
1
[
c
t
−
1
∗
f
t
]
+
∂
∂
c
t
−
1
[
c
t
~
∗
i
t
]
=
∂
f
t
∂
c
t
−
1
⋅
c
t
−
1
+
∂
c
t
−
1
∂
c
t
−
1
⋅
f
t
+
∂
i
t
∂
c
t
−
1
⋅
c
t
~
+
∂
c
t
~
∂
c
t
−
1
⋅
i
t
=
∂
f
t
∂
h
t
−
1
⋅
∂
h
t
−
1
∂
c
t
−
1
⋅
c
t
−
1
+
∂
i
t
∂
h
t
−
1
⋅
∂
h
t
−
1
∂
c
t
−
1
⋅
c
t
~
+
∂
c
t
~
∂
h
t
−
1
⋅
∂
h
t
−
1
∂
c
t
−
1
⋅
i
t
=
σ
′
(
W
f
⋅
[
h
t
−
1
,
x
t
]
)
⋅
W
f
⋅
o
t
−
1
⋅
t
a
n
h
′
(
c
t
−
1
)
⋅
c
t
−
1
+
f
t
+
σ
′
(
W
i
⋅
[
h
t
−
1
,
x
t
]
)
⋅
W
i
⋅
o
t
−
1
⋅
t
a
n
h
′
(
c
t
−
1
)
⋅
c
t
~
+
σ
′
(
W
c
⋅
[
h
t
−
1
,
x
t
]
)
⋅
W
c
⋅
o
t
−
1
⋅
t
a
n
h
′
(
c
t
−
1
)
⋅
i
t
\begin{aligned} \frac{\partial c_t}{\partial c_{t-1}}&=\frac{\partial }{\partial c_{t-1}}[c_{t-1} * f_t + \tilde{c_t} * i_t]\\ &= \frac{\partial }{\partial c_{t-1}}[c_{t-1} * f_t]+\frac{\partial }{\partial c_{t-1}}[\tilde{c_{t}} * i_t]\\ &=\frac{\partial f_t}{\partial c_{t-1}}\cdot c_{t-1}+\frac{\partial c_{t-1}}{\partial c_{t-1}}\cdot f_t+\frac{\partial i_t}{\partial c_{t-1}}\cdot \tilde{c_t}+\frac{\partial \tilde{c_t}}{\partial c_{t-1}}\cdot i_t\\ &= \frac{\partial f_t}{\partial h_{t-1}} \cdot \frac{\partial h_{t-1}}{\partial c_{t-1}}\cdot c_{t-1}+\frac{\partial i_t}{\partial h_{t-1}}\cdot \frac{\partial h_{t-1}}{\partial c_{t-1}}\cdot\tilde{c_t}+\frac{\partial \tilde{c_t}}{\partial h_{t-1}}\cdot \frac{\partial h_{t-1}}{\partial c_{t-1}}\cdot i_t\\ &= \sigma'(W_f\cdot [h_{t-1},x_t]) \cdot W_f \cdot o_{t-1} \cdot tanh'(c_{t-1}) \cdot c_{t-1}\\ & \ \ \ +f_t\\ & \ \ \ + \sigma'(W_i\cdot [h_{t-1},x_t]) \cdot W_i \cdot o_{t-1} \cdot tanh'(c_{t-1}) \cdot \tilde{c_t}\\ & \ \ \ + \sigma'(W_c\cdot [h_{t-1},x_t]) \cdot W_c \cdot o_{t-1} \cdot tanh'(c_{t-1}) \cdot i_t \end{aligned}
∂ct−1∂ct=∂ct−1∂[ct−1∗ft+ct~∗it]=∂ct−1∂[ct−1∗ft]+∂ct−1∂[ct~∗it]=∂ct−1∂ft⋅ct−1+∂ct−1∂ct−1⋅ft+∂ct−1∂it⋅ct~+∂ct−1∂ct~⋅it=∂ht−1∂ft⋅∂ct−1∂ht−1⋅ct−1+∂ht−1∂it⋅∂ct−1∂ht−1⋅ct~+∂ht−1∂ct~⋅∂ct−1∂ht−1⋅it=σ′(Wf⋅[ht−1,xt])⋅Wf⋅ot−1⋅tanh′(ct−1)⋅ct−1 +ft +σ′(Wi⋅[ht−1,xt])⋅Wi⋅ot−1⋅tanh′(ct−1)⋅ct~ +σ′(Wc⋅[ht−1,xt])⋅Wc⋅ot−1⋅tanh′(ct−1)⋅it
如果把这四项分别用
A
,
B
,
C
,
D
A,B,C,D
A,B,C,D替代的话,公式就可以变成:
∂
c
t
∂
c
t
−
1
=
A
t
+
B
t
+
C
t
+
D
t
\frac{\partial c_t}{\partial c_{t-1}}=A_t+B_t+C_t+D_t
∂ct−1∂ct=At+Bt+Ct+Dt
把这个简洁的公式带入之前的误差公式:
∂
E
k
∂
W
=
∂
E
k
∂
h
k
∂
h
k
∂
c
k
(
∏
t
=
2
k
∂
c
t
∂
c
t
−
1
)
∂
c
1
∂
W
=
∂
E
k
∂
h
k
∂
h
k
∂
c
k
(
∏
t
=
2
k
[
A
t
+
B
t
+
C
t
+
D
t
]
)
∂
c
1
∂
W
\begin{aligned} \frac{\partial E_k}{\partial W}&=\frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}(\prod_{t=2}^k\frac{\partial c_t}{\partial c_{t-1}})\frac{\partial c_1}{\partial W}\\ &=\frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}(\prod_{t=2}^k[A_t+B_t+C_t+D_t])\frac{\partial c_1}{\partial W} \end{aligned}
∂W∂Ek=∂hk∂Ek∂ck∂hk(t=2∏k∂ct−1∂ct)∂W∂c1=∂hk∂Ek∂ck∂hk(t=2∏k[At+Bt+Ct+Dt])∂W∂c1
缓解梯度消失/爆炸问题
有连乘,那就说明有可能造成梯度消失和爆炸。上文也讲了 ∏ t = 2 k ∂ c t ∂ c t − 1 \prod_{t=2}^k\frac{\partial c_t}{\partial c_{t-1}} ∏t=2k∂ct−1∂ct里面有什么,总共四项,如果看的云里雾里也没事,因为那个 B t B_t Bt 你一定看的懂!因为 B t B_t Bt 只有一个内容 f t f_t ft,我们可以轻松地直观地通过他调整 f t f_t ft 的大小以适应其他三个项的值,然后是的连乘出来的结果不会非常小。
接下来我们看
f
t
f_t
ft 到底怎么能帮助我们。现在假设对某一个时间步
k
<
T
k<T
k<T,我们有:
∑
t
=
1
k
∂
E
t
∂
W
→
0
\sum_{t=1}^k \frac{\partial E_t}{\partial W}\rightarrow 0
t=1∑k∂W∂Et→0
然后为了梯度不消失,我们可以再时间步
k
+
1
k+1
k+1 找到一个合适的
W
f
W_f
Wf 使得:
∂
E
k
+
1
∂
W
↛
0
\frac{\partial E_{k+1}}{\partial W}\nrightarrow 0
∂W∂Ek+1↛0
由于遗忘门的激活函数和梯度项中大家都是相加的(A,B,C,D,加性结构),所以使得 LSTM 在任何时间步都能找到这样的
W
f
W_f
Wf 使得:
∑
t
=
1
k
+
1
∂
E
t
∂
W
↛
0
\sum_{t=1}^{k+1}\frac{\partial E_t}{\partial W}\nrightarrow 0
t=1∑k+1∂W∂Et↛0
这样梯度就不会消失了。
另一个重要的性质: 正如上文说到的加性结构,四个项可以相互平衡从而保证在反向传播的时候梯度值不会消失。
举个例子:假设 时间步
t
∈
{
2
,
3
,
.
.
.
,
k
}
t \in \{2,3,...,k\}
t∈{2,3,...,k},我们对梯度值中的四项设置一个相互平衡的值(从而保证梯度不消失):
A
t
≈
0.1
⃗
,
B
t
=
f
t
≈
0.7
⃗
,
C
t
≈
0.1
⃗
,
D
t
≈
0.1
⃗
A_t \approx \vec{0.1},B_t = f_t \approx \vec{0.7},C_t \approx \vec{0.1},D_t \approx \vec{0.1}
At≈0.1,Bt=ft≈0.7,Ct≈0.1,Dt≈0.1
带入连乘公式:
∏
t
=
2
k
[
A
t
+
B
t
+
C
t
+
D
t
]
≈
∏
t
=
2
k
[
0.1
⃗
+
0.7
⃗
+
0.1
⃗
+
0.1
⃗
]
≈
∏
t
=
2
k
1
⃗
↛
0
\begin{aligned} \prod_{t=2}^k[A_t+B_t+C_t+D_t] &\approx \prod_{t=2}^k[\vec{0.1}+\vec{0.7}+\vec{0.1}+\vec{0.1}]\\ & \approx \prod_{t=2}^k \vec{1} \nrightarrow 0\end{aligned}
t=2∏k[At+Bt+Ct+Dt]≈t=2∏k[0.1+0.7+0.1+0.1]≈t=2∏k1↛0
这时候就算是连乘,梯度也不会消失了。
所以,在 LSTM 中,遗忘门的存在,以及细胞状态梯度的加性特性,使网络能够以这样一种方式更新参数,即不同子梯度之间的平衡从而避免梯度消失。
但看到这,也就清楚了,因为我们都是正数相加,所以不能够避免梯度爆炸,当
A
,
C
,
D
A,C,D
A,C,D的数值很大的时候,
f
t
f_t
ft 也没办法去平衡防止梯度爆炸。