一、RNN梯度更新过程
对于循环神经网络,在训练语言模型或序列标注任务中,每一个隐层输出与实际输出都对于产生一个损失函数
J
(
θ
)
\displaystyle J( \theta )
J(θ)。如
J
3
(
θ
)
\displaystyle J^{3}( \theta )
J3(θ):loss as time 3表示在第3时刻的损失。RNN采用基于时间的反向传播算法BPTT(Back Propagation Trough Time)算法不断对参数进行优化,使损失达到最优。
基于复合函数求导法则:
∂
J
(
4
)
(
θ
)
∂
h
(
1
)
=
∂
J
(
4
)
(
θ
)
∂
h
(
4
)
∂
h
(
4
)
∂
h
(
1
)
=
∂
J
(
4
)
(
θ
)
∂
h
(
4
)
∂
h
(
4
)
∂
h
3
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
h
(
1
)
\displaystyle \frac{\partial J^{( 4)}( \theta )}{\partial h^{( 1)}} =\frac{\partial J^{( 4)}( \theta )}{\partial h^{( 4)}}\frac{\partial h^{( 4)}}{\partial h^{( 1)}} =\frac{\partial J^{( 4)}( \theta )}{\partial h^{( 4)}}\frac{\partial h^{( 4)}}{\partial h^{3}}\frac{\partial h^{( 3)}}{\partial h^{( 2)}}\frac{\partial h^{( 2)}}{\partial h^{( 1)}}
∂h(1)∂J(4)(θ)=∂h(4)∂J(4)(θ)∂h(1)∂h(4)=∂h(4)∂J(4)(θ)∂h3∂h(4)∂h(2)∂h(3)∂h(1)∂h(2)
其中:
h
t
=
σ
(
w
h
h
⋅
h
t
−
1
+
w
x
h
⋅
x
t
+
b
1
)
\displaystyle h_{t} =\sigma ( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1})
ht=σ(whh⋅ht−1+wxh⋅xt+b1)
∂
h
t
∂
h
t
−
1
=
σ
′
(
w
h
h
⋅
h
t
−
1
+
w
x
h
⋅
x
t
+
b
1
)
⋅
w
h
h
\displaystyle \frac{\partial h_{t}}{\partial h_{t-1}} =\sigma ^{'}( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1}) \cdot w_{hh}
∂ht−1∂ht=σ′(whh⋅ht−1+wxh⋅xt+b1)⋅whh
对上式进行整理:
∂
J
(
i
)
(
θ
)
∂
h
j
=
∂
J
(
i
)
(
θ
)
∂
h
i
⋅
∏
i
<
t
⩽
i
∂
h
t
∂
h
t
−
1
=
∂
J
(
i
)
(
θ
)
∂
h
i
∏
i
<
t
⩽
i
σ
′
(
w
h
h
⋅
h
t
−
1
+
w
x
h
⋅
x
t
+
b
1
)
⋅
w
h
h
=
∂
J
(
i
)
(
θ
)
∂
h
i
w
h
h
i
−
j
∏
i
<
t
⩽
i
σ
′
(
w
h
h
⋅
h
t
−
1
+
w
x
h
⋅
x
t
+
b
1
)
\displaystyle \frac{\partial J^{( i)}( \theta )}{\partial h_{j}} =\frac{\partial J^{( i)}( \theta )}{\partial h_{i}} \cdot \prod _{i< t\leqslant i}\frac{\partial h_{t}}{\partial h_{t-1}} =\frac{\partial J^{( i)}( \theta )}{\partial h_{i}}\prod _{i< t\leqslant i} \sigma ^{'}( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1}) \cdot w_{hh}=\displaystyle \frac{\partial J^{( i)}( \theta )}{\partial h_{i}} w^{i-j}_{hh}\prod _{i< t\leqslant i} \sigma ^{'}( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1})
∂hj∂J(i)(θ)=∂hi∂J(i)(θ)⋅i<t⩽i∏∂ht−1∂ht=∂hi∂J(i)(θ)i<t⩽i∏σ′(whh⋅ht−1+wxh⋅xt+b1)⋅whh=∂hi∂J(i)(θ)whhi−ji<t⩽i∏σ′(whh⋅ht−1+wxh⋅xt+b1)
2、梯度消失和梯度爆炸如何产生的
对上式两边求范式:
∥
∂
J
(
i
)
(
θ
)
∂
h
j
∥
⩽
∥
∂
J
(
i
)
(
θ
)
∂
h
i
∥
⋅
∥
w
h
h
i
−
j
∥
⋅
∏
i
<
t
⩽
i
∥
σ
′
(
w
h
h
⋅
h
t
−
1
+
w
x
h
⋅
x
t
+
b
1
)
∥
\displaystyle \parallel \frac{\partial J^{( i)}( \theta )}{\partial h_{j}} \parallel \leqslant \parallel \frac{\partial J^{( i)}( \theta )}{\partial h_{i}} \parallel \cdot \parallel w^{i-j}_{hh} \parallel \cdot \prod _{i< t\leqslant i} \parallel \sigma ^{'}( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1}) \parallel
∥∂hj∂J(i)(θ)∥⩽∥∂hi∂J(i)(θ)∥⋅∥whhi−j∥⋅i<t⩽i∏∥σ′(whh⋅ht−1+wxh⋅xt+b1)∥
如果不等式左边小于1,对于时序问题,n个时刻的连乘,会导致梯度消失。
如果不等式左边大于1,对于时序问题,n个时刻的连乘,会导致梯度爆炸。
3、如何缓解梯度消失和梯度爆炸问题
对于参数
θ
\displaystyle \theta
θ的优化过程:
θ
(
t
+
1
)
=
θ
(
t
)
−
α
∇
θ
f
(
θ
(
t
)
)
\displaystyle \theta ^{( t+1)} =\theta ^{( t)} -\alpha \nabla _{\theta } f\left( \theta ^{( t)}\right)
θ(t+1)=θ(t)−α∇θf(θ(t))。
∇
θ
f
(
θ
(
t
)
)
\displaystyle \nabla _{\theta } f\left( \theta ^{( t)}\right)
∇θf(θ(t))为梯度,当它足够大时,就会引起梯度爆炸。梯度爆炸相对于梯度消失较好解决。下面介绍几种解决方法:
(1)梯度裁剪(Gradient Clipping)
梯度剪切这个方案主要是针对梯度爆炸提出的,其思想是设置一个梯度剪切阈值,然后更新梯度的时候,如果梯度超过这个阈值,那么就将其强制限制在这个范围之内。这可以防止梯度爆炸。
- 令 ∇ θ f ( θ ( t ) ) \displaystyle \nabla _{\theta } f\left( \theta ^{( t)}\right) ∇θf(θ(t))等于 g ( t ) \displaystyle g^{( t)} g(t):
- i f ∥ g ( t ) ∥ ⩾ t h r e s h o l d \displaystyle if\ \parallel g^{( t)} \parallel \geqslant threshold if ∥g(t)∥⩾threshold:
- t h e n then then : g ( t ) = g ( t ) ⋅ t h r e s h o l d ∥ g ( t ) ∥ \displaystyle g^{( t)} =g^{( t)} \cdot \frac{threshold}{\parallel g^{( t)} \parallel } g(t)=g(t)⋅∥g(t)∥threshold------相当于在原来梯度的基础上做了一个归一化的操作。
- 用新的梯度再去进行梯度下降法,更新参数 θ \displaystyle \theta θ
(2)增加非饱和的激活函数(如 ReLU)
Relu的思想很简单,如果激活函数的导数为1,那么就不存在梯度消失爆炸的问题了,每层的网络都可以得到相同的更新速度,relu就这样应运而生。
从上图中,可以很容易看出,relu函数的导数在正数部分是恒等于1的,因此在深层网络中使用relu激活函数就不会导致梯度消失和爆炸的问题。但relu也有缺点由于负数部分恒为0,会导致一些神经元无法激活(可通过设置小学习率部分解决)。
(3)批规范化:
将输出信号x规范化到均值为0,方差为1,保证网络的稳定性。在BPTT中的应用,反向传播中,经过每一层的梯度会乘以该层的权重。
eg:征相传播,那么反向传播中,反向传播式子中有
w
w
w的存在,所以
w
w
w的大小影响了梯度的消失和爆炸,batchnorm就是通过对每一层的输出规范为均值和方差一致的方法,消除了
w
w
w带来的放大虽小的影响,进而解决梯度消失和梯度爆炸的问题
(4)LSTM:
全称是长短期记忆网络(long-short term memory networks),是不那么容易发生梯度消失的,主要原因在于LSTM内部复杂的“门”。通过不同的门去控制整个信息流,选择性的忘记不需要记忆的信息。避免梯度消失或梯度爆炸问题。使得RNN可以捕捉到有效的更长的信息。
(5)还有,预训练加微调、残差结构等方法,因不太使用,这里不展开说明。