梯度消失和梯度爆炸

一、RNN梯度更新过程

对于循环神经网络,在训练语言模型或序列标注任务中,每一个隐层输出与实际输出都对于产生一个损失函数 J ( θ ) \displaystyle J( \theta ) J(θ)。如 J 3 ( θ ) \displaystyle J^{3}( \theta ) J3(θ):loss as time 3表示在第3时刻的损失。RNN采用基于时间的反向传播算法BPTT(Back Propagation Trough Time)算法不断对参数进行优化,使损失达到最优。
在这里插入图片描述
基于复合函数求导法则:
∂ J ( 4 ) ( θ ) ∂ h ( 1 ) = ∂ J ( 4 ) ( θ ) ∂ h ( 4 ) ∂ h ( 4 ) ∂ h ( 1 ) = ∂ J ( 4 ) ( θ ) ∂ h ( 4 ) ∂ h ( 4 ) ∂ h 3 ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ h ( 1 ) \displaystyle \frac{\partial J^{( 4)}( \theta )}{\partial h^{( 1)}} =\frac{\partial J^{( 4)}( \theta )}{\partial h^{( 4)}}\frac{\partial h^{( 4)}}{\partial h^{( 1)}} =\frac{\partial J^{( 4)}( \theta )}{\partial h^{( 4)}}\frac{\partial h^{( 4)}}{\partial h^{3}}\frac{\partial h^{( 3)}}{\partial h^{( 2)}}\frac{\partial h^{( 2)}}{\partial h^{( 1)}} h(1)J(4)(θ)=h(4)J(4)(θ)h(1)h(4)=h(4)J(4)(θ)h3h(4)h(2)h(3)h(1)h(2)
其中:
h t = σ ( w h h ⋅ h t − 1 + w x h ⋅ x t + b 1 ) \displaystyle h_{t} =\sigma ( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1}) ht=σ(whhht1+wxhxt+b1)
∂ h t ∂ h t − 1 = σ ′ ( w h h ⋅ h t − 1 + w x h ⋅ x t + b 1 ) ⋅ w h h \displaystyle \frac{\partial h_{t}}{\partial h_{t-1}} =\sigma ^{'}( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1}) \cdot w_{hh} ht1ht=σ(whhht1+wxhxt+b1)whh
对上式进行整理:
∂ J ( i ) ( θ ) ∂ h j = ∂ J ( i ) ( θ ) ∂ h i ⋅ ∏ i < t ⩽ i ∂ h t ∂ h t − 1 = ∂ J ( i ) ( θ ) ∂ h i ∏ i < t ⩽ i σ ′ ( w h h ⋅ h t − 1 + w x h ⋅ x t + b 1 ) ⋅ w h h = ∂ J ( i ) ( θ ) ∂ h i w h h i − j ∏ i < t ⩽ i σ ′ ( w h h ⋅ h t − 1 + w x h ⋅ x t + b 1 ) \displaystyle \frac{\partial J^{( i)}( \theta )}{\partial h_{j}} =\frac{\partial J^{( i)}( \theta )}{\partial h_{i}} \cdot \prod _{i< t\leqslant i}\frac{\partial h_{t}}{\partial h_{t-1}} =\frac{\partial J^{( i)}( \theta )}{\partial h_{i}}\prod _{i< t\leqslant i} \sigma ^{'}( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1}) \cdot w_{hh}=\displaystyle \frac{\partial J^{( i)}( \theta )}{\partial h_{i}} w^{i-j}_{hh}\prod _{i< t\leqslant i} \sigma ^{'}( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1}) hjJ(i)(θ)=hiJ(i)(θ)i<tiht1ht=hiJ(i)(θ)i<tiσ(whhht1+wxhxt+b1)whh=hiJ(i)(θ)whhiji<tiσ(whhht1+wxhxt+b1)

2、梯度消失和梯度爆炸如何产生的

对上式两边求范式:
∥ ∂ J ( i ) ( θ ) ∂ h j ∥ ⩽ ∥ ∂ J ( i ) ( θ ) ∂ h i ∥ ⋅ ∥ w h h i − j ∥ ⋅ ∏ i < t ⩽ i ∥ σ ′ ( w h h ⋅ h t − 1 + w x h ⋅ x t + b 1 ) ∥ \displaystyle \parallel \frac{\partial J^{( i)}( \theta )}{\partial h_{j}} \parallel \leqslant \parallel \frac{\partial J^{( i)}( \theta )}{\partial h_{i}} \parallel \cdot \parallel w^{i-j}_{hh} \parallel \cdot \prod _{i< t\leqslant i} \parallel \sigma ^{'}( w_{hh} \cdot h_{t-1} +w_{xh} \cdot x_{t} +b_{1}) \parallel hjJ(i)(θ)hiJ(i)(θ)whhiji<tiσ(whhht1+wxhxt+b1)
如果不等式左边小于1,对于时序问题,n个时刻的连乘,会导致梯度消失。
如果不等式左边大于1,对于时序问题,n个时刻的连乘,会导致梯度爆炸。

3、如何缓解梯度消失和梯度爆炸问题

对于参数 θ \displaystyle \theta θ的优化过程:
θ ( t + 1 ) = θ ( t ) − α ∇ θ f ( θ ( t ) ) \displaystyle \theta ^{( t+1)} =\theta ^{( t)} -\alpha \nabla _{\theta } f\left( \theta ^{( t)}\right) θ(t+1)=θ(t)αθf(θ(t)) ∇ θ f ( θ ( t ) ) \displaystyle \nabla _{\theta } f\left( \theta ^{( t)}\right) θf(θ(t))为梯度,当它足够大时,就会引起梯度爆炸。梯度爆炸相对于梯度消失较好解决。下面介绍几种解决方法:
(1)梯度裁剪(Gradient Clipping)
梯度剪切这个方案主要是针对梯度爆炸提出的,其思想是设置一个梯度剪切阈值,然后更新梯度的时候,如果梯度超过这个阈值,那么就将其强制限制在这个范围之内。这可以防止梯度爆炸。

  • ∇ θ f ( θ ( t ) ) \displaystyle \nabla _{\theta } f\left( \theta ^{( t)}\right) θf(θ(t))等于 g ( t ) \displaystyle g^{( t)} g(t)
  • i f   ∥ g ( t ) ∥ ⩾ t h r e s h o l d \displaystyle if\ \parallel g^{( t)} \parallel \geqslant threshold if g(t)threshold
  • t h e n then then g ( t ) = g ( t ) ⋅ t h r e s h o l d ∥ g ( t ) ∥ \displaystyle g^{( t)} =g^{( t)} \cdot \frac{threshold}{\parallel g^{( t)} \parallel } g(t)=g(t)g(t)threshold------相当于在原来梯度的基础上做了一个归一化的操作。
  • 用新的梯度再去进行梯度下降法,更新参数 θ \displaystyle \theta θ

(2)增加非饱和的激活函数(如 ReLU)
Relu的思想很简单,如果激活函数的导数为1,那么就不存在梯度消失爆炸的问题了,每层的网络都可以得到相同的更新速度,relu就这样应运而生。
在这里插入图片描述
从上图中,可以很容易看出,relu函数的导数在正数部分是恒等于1的,因此在深层网络中使用relu激活函数就不会导致梯度消失和爆炸的问题。但relu也有缺点由于负数部分恒为0,会导致一些神经元无法激活(可通过设置小学习率部分解决)。
(3)批规范化:
将输出信号x规范化到均值为0,方差为1,保证网络的稳定性。在BPTT中的应用,反向传播中,经过每一层的梯度会乘以该层的权重。
eg:征相传播,那么反向传播中,反向传播式子中有 w w w的存在,所以 w w w的大小影响了梯度的消失和爆炸,batchnorm就是通过对每一层的输出规范为均值和方差一致的方法,消除了 w w w带来的放大虽小的影响,进而解决梯度消失和梯度爆炸的问题
(4)LSTM:
全称是长短期记忆网络(long-short term memory networks),是不那么容易发生梯度消失的,主要原因在于LSTM内部复杂的“门”。通过不同的门去控制整个信息流,选择性的忘记不需要记忆的信息。避免梯度消失或梯度爆炸问题。使得RNN可以捕捉到有效的更长的信息。
(5)还有,预训练加微调、残差结构等方法,因不太使用,这里不展开说明。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值