RNN经过许多阶段传播后的梯度倾向于消失(大部分情况)或爆炸(很少,但对优化过程影响很大)。【比如说对于 t = 1 , 2 t=1, 2 t=1,2时刻的输入向量的梯度 ∂ J ∂ x 2 = ∂ J ∂ h 2 U T \dfrac{\partial J}{\partial x_{2}}=\dfrac{\partial J}{\partial h_{2}} U^{T} ∂x2∂J=∂h2∂JUT,可以被 t = 4 t=4 t=4的 h 4 h_4 h4有效影响。但是对于 t = 99 , 100 t=99, 100 t=99,100的 h 99 , h 100 h_{99}, h_{100} h99,h100,可能就无法有效影响了。因为 h 1 h_1 h1依赖 S 1 S_1 S1, S 1 S_1 S1依赖 h 2 h_2 h2,一直递归下去。】
梯度消失或者爆炸
梯度爆炸:在深层的神经网络中,由于多个权重矩阵的相乘,会 出现很多如图所示的陡峭区域,当然也有可能会出现 很多非常平坦的区域。在这些陡峭的地方,Loss函数 的导数非常大,导致最终的梯度也很大,对参数进行更新后可能会导致参数的取值超出有效的取值范围,
梯度消失:而在那些非常平坦的地方,Loss的变化很小,这个时候梯度的值也会很小(可能趋近于0),导致参数的更新非常缓慢,甚至更新的方向都不明确。
回到上一篇博客中,我们定义一个简化的循环神经网络,该网络中的所有激活函数均为线性的,除了在每个时间步上共享的参数W以外,其它的权重矩阵均设为1,偏置项均设为0。我们又假设输入的序列中除了 x 0 = 1 x_0=1 x0=1, 其他输入的值为0, 如下图
那么,我们从前向传播的角度看,RNN的输出是关于权重矩阵
W
W
W的指数函数
h
0
=
1
h
1
=
W
h
2
=
W
2
⋮
h
t
=
W
t
\begin{array}{l} h_{0}=1 \\ h_{1}=W \\ h_{2}=W^{2} \\ \vdots \\ h_{t}=W^{t} \end{array}
h0=1h1=Wh2=W2⋮ht=Wt
当W的值大于1时,随着
t
t
t的增加,神经网络最终输出的值也成指数级增长,而当W的值小于1时,随着
t
t
t的值增加,神经网络最终的输出则会非常小。这两种结果分别是导致梯度爆炸和梯度消失的根本原因。从例子可以看到,循环神经网络中梯度消失和梯度爆炸问题产生的根本原因,是由于参数共享导致的。
如果从后向传播的角度看,也会有类似问题
∂
J
∂
S
t
−
1
=
∂
J
∂
O
t
−
1
V
T
+
∂
J
∂
h
t
W
T
∂
J
∂
h
t
=
∂
J
∂
S
t
d
S
t
d
h
t
⇒
∂
J
∂
S
t
−
1
=
∂
J
∂
O
t
−
1
V
T
+
∂
J
∂
S
t
d
S
t
d
h
t
W
T
≈
f
(
t
)
∂
J
∂
S
t
\begin{array}{l} &\dfrac{\partial J}{\partial S_{t-1}}=\dfrac{\partial J}{\partial O_{t-1}} V^{T}+\dfrac{\partial J}{\partial h_{t}} W^{T} \\ &\dfrac{\partial J}{\partial h_{t}}=\dfrac{\partial J}{\partial S_{t}} \dfrac{d S_{t}}{d h_{t}}\\ \Rightarrow & \dfrac{\partial J}{\partial S_{t-1}}=\dfrac{\partial J}{\partial O_{t-1}} V^{T}+\dfrac{\partial J}{\partial S_{t}} \dfrac{d S_{t}}{d h_{t}} W^{T} \approx f(t) \dfrac{\partial J}{\partial S_{t}} \end{array}
⇒∂St−1∂J=∂Ot−1∂JVT+∂ht∂JWT∂ht∂J=∂St∂JdhtdSt∂St−1∂J=∂Ot−1∂JVT+∂St∂JdhtdStWT≈f(t)∂St∂J
因此对于
t
t
t时刻的前
k
k
k时刻
∂
J
∂
S
t
−
k
=
f
1
(
t
)
f
2
(
t
)
⋯
f
k
(
t
)
∂
J
∂
S
t
\frac{\partial J}{\partial S_{t-k}}=f_{1}(t) f_{2}(t) \cdots f_{k}(t) \frac{\partial J}{\partial S_{t}}
∂St−k∂J=f1(t)f2(t)⋯fk(t)∂St∂J
这里如果忽略
d
S
t
d
h
t
\dfrac{d S_{t}}{d h_{t}}
dhtdSt, 那么自然又有与前向传播时类似的问题,
f
1
f
2
⋯
f
k
f_{1}f_{2}\cdots f_{k}
f1f2⋯fk与W的k次方相关,
f
1
(
t
)
f
2
(
t
)
⋯
f
k
(
t
)
∝
(
W
T
)
k
f_{1}(t) f_{2}(t) \cdots f_{k}(t) \propto (W^{T})^{k}
f1(t)f2(t)⋯fk(t)∝(WT)k, 因此会出现梯度消失或者梯度爆炸。比如假设
W
W
W的值远小于1, 显然随着
t
,
k
t, k
t,k的值增加,
∂
J
∂
S
t
−
k
\dfrac{\partial J}{\partial S_{t-k}}
∂St−k∂J会越来越小接近于0。
这里我们可以采用一种简单的方式解决,也就是在反向传播的过程中, 每隔k时间段就清除下一时刻传来的梯度。比如说,设定k=20, 在 t − 20 t-20 t−20时刻,梯度 ∂ J ∂ S t − k \frac{\partial J}{\partial S_{t-k}} ∂St−k∂J重新等于0, 那么 ∂ J ∂ S t − 20 = ∂ J ∂ O t − 20 V T + ∂ J ∂ h t − 19 W T = ∂ J ∂ O t − 20 V T \dfrac{\partial J}{\partial S_{t-20}}=\dfrac{\partial J}{\partial O_{t-20}} V^{T}+\dfrac{\partial J}{\partial h_{t-19}} W^{T} = \dfrac{\partial J}{\partial O_{t-20}} V^{T} ∂St−20∂J=∂Ot−20∂JVT+∂ht−19∂JWT=∂Ot−20∂JVT, 也就是只保留第一项。
另一个方法就是就是采用gradient clip。比如只允许[-100, 100]范围内的梯度,如果梯度值变成了500, 就强制变为100。这样也能减缓这个问题。