理论推导RNN梯度消失和梯度爆炸的原因

  RNN的提出是为了解决网络无法利用历史信息的问题,但由于RNN具有梯度消失和梯度爆炸的问题,导致RNN不能存储长期记忆。

网络结构

  首先来看RNN的结构,如下图1所示:

RNN结构

  上图的结构很好理解, x t x_{t} xt为网络输入, A A A为隐藏层, h t h_{t} ht为网络输出。既然我们想利用之前的历史信息,那我们就将网络在上一时刻的输出保存下来,作为当前时刻的输入,也就是上图中的反馈连接。我们将上图中的RNN结构按时序展开,如下图2所示:

RNN时序展开

   x 0 x_{0} x0~ x t x_{t} xt是网络在不同时刻的输入, h 0 h_{0} h0 ~ h t h_{t} ht是网络在不同时刻的输出,A是隐藏层。需要注意的是,上图中的RNN展开图是RNN按时序的展开图,并不是真正的拓扑结构,对于某一固定的时刻 t t t,RNN的结构就是图1;这是很多资料容易让人产生误解的地方。所以,图2中的那么多A其实是同一个隐藏层,这也就是RNN中的“参数共享”。当然,你也可以增加RNN的深度,即增加隐藏层,如下图3所示:

QiVFKA.md.png

如上图所示,纵向是增加网络深度,横向是增加时间步。

工作原理

  介绍了RNN的网络结构,下面来看RNN的工作过程。我们假设网络只有一个隐藏层,网络输入为 x x x,输出为 y y y,隐藏层状态为 h h h,如下图4所示,

QiKYI1.md.jpg

则在时刻 t t t有:

h t = f ( w i x + w h h t − 1 ) h_{t}=f(w_{i}x+w_{h}h_{t-1}) ht=f(wix+whht1)
y t = f ( w o h t ) y_{t}=f(w_{o}h_{t}) yt=f(woht)

上式中, f f f为激活函数,一般为 s i g m o i d sigmoid sigmoid t a n h tanh tanh

梯度消失与梯度爆炸

  了解了RNN的工作原理,下面我们就可以去分析RNN梯度消失和梯度爆炸的原因了。为了简化问题,只考虑三个时间步,如下图5所示:

Qiu63T.md.jpg

则有:

h 1 = f ( w i x 1 + w h h 0 ) , y 1 = f ( w o h 1 ) h_{1}=f(w_{i}x_{1}+w_{h}h_{0}) , y_{1}=f(w_{o}h_{1}) h1=f(wix1+whh0),y1=f(woh1)

h 2 = f ( w i x 2 + w h h 1 ) , y 2 = f ( w o h 2 ) h_{2}=f(w_{i}x_{2}+w_{h}h_{1}) , y_{2}=f(w_{o}h_{2}) h2=f(wix2+whh1),y2=f(woh2)

h 3 = f ( w i x 3 + w h h 2 ) , y 3 = f ( w o h 3 ) h_{3}=f(w_{i}x_{3}+w_{h}h_{2}) , y_{3}=f(w_{o}h_{3}) h3=f(wix3+whh2),y3=f(woh3)

RNN的损失函数为

L = ∑ t = 0 T L t = ∑ t = 0 T g ( y t ) L=\sum_{t=0}^{T}L_{t}=\sum_{t=0}^{T}g(y_{t}) L=t=0TLt=t=0Tg(yt)

L t L_{t} Lt t t t时刻输出的损失, g g g为网络的损失函数。根据链式求导法则,求L对各个参数的偏导即为参数更新的梯度。

先只考虑 L 3 L_{3} L3求偏导,有:

∂ L 3 ∂ w o = ∂ L 3 ∂ y 3 ∂ y 3 ∂ w o \frac{\partial L_{3}}{\partial w_{o}}=\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial w_{o}} woL3=y3L3woy3

∂ L 3 ∂ w i = ∂ L 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ w i + ∂ L 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ h 2 ∂ h 2 ∂ w i + ∂ L 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ h 2 ∂ h 2 ∂ h 1 ∂ h 1 ∂ w i \frac{\partial L_{3}}{\partial w_{i}}=\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial w_{i}}+\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial h_{2}}\frac{\partial h_{2}}{\partial w_{i}}+\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial h_{2}}\frac{\partial h_{2}}{\partial h_{1}}\frac{\partial h_{1}}{\partial w_{i}} wiL3=y3L3h3y3wih3+y3L3h3y3h2h3wih2+y3L3h3y3h2h3h1h2wih1

∂ L 3 ∂ w h = ∂ L 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ w h + ∂ L 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ h 2 ∂ h 2 ∂ w h + ∂ L 3 ∂ y 3 ∂ y 3 ∂ h 3 ∂ h 3 ∂ h 2 ∂ h 2 ∂ h 1 ∂ h 1 ∂ w h \frac{\partial L_{3}}{\partial w_{h}}=\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial w_{h}}+\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial h_{2}}\frac{\partial h_{2}}{\partial w_{h}}+\frac{\partial L_{3}}{\partial y_{3}}\frac{\partial y_{3}}{\partial h_{3}}\frac{\partial h_{3}}{\partial h_{2}}\frac{\partial h_{2}}{\partial h_{1}}\frac{\partial h_{1}}{\partial w_{h}} whL3=y3L3h3y3whh3+y3L3h3y3h2h3whh2+y3L3h3y3h2h3h1h2whh1

观察上式,由于 h t , t ∈ ( 0 , T ) h_{t},t\in (0,T) ht,t(0,T)的存在,使得损失函数对参数求偏导的过程中存在大量的复合求导。再将上述等式推广到所有时间步,则有

∂ L ∂ w o = ∑ t = 0 T ∂ L t ∂ y t ∂ y t ∂ w o \frac{\partial L}{\partial w_{o}}=\sum_{t=0}^{T}\frac{\partial L_{t}}{\partial y_{t}}\frac{\partial y_{t}}{\partial w_{o}} woL=t=0TytLtwoyt

∂ L ∂ w i = ∑ t = 0 T ∑ j = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ k = j + 1 t ∂ h k ∂ h k − 1 ) ∂ h j ∂ w i \frac{\partial L}{\partial w_{i}}=\sum_{t=0}^{T}\sum_{j=0}^{t}\frac{\partial L_{t}}{\partial y_{t}}\frac{\partial y_{t}}{\partial h_{t}}(\prod_{k=j+1}^{t}\frac{\partial h_{k}}{\partial h_{k-1}})\frac{\partial h_{j}}{\partial w_{i}} wiL=t=0Tj=0tytLthtyt(k=j+1thk1hk)wihj

∂ L ∂ w h = ∑ t = 0 T ∑ j = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ k = j + 1 t ∂ h k ∂ h k − 1 ) ∂ h j ∂ w h \frac{\partial L}{\partial w_{h}}=\sum_{t=0}^{T}\sum_{j=0}^{t}\frac{\partial L_{t}}{\partial y_{t}}\frac{\partial y_{t}}{\partial h_{t}}(\prod_{k=j+1}^{t}\frac{\partial h_{k}}{\partial h_{k-1}})\frac{\partial h_{j}}{\partial w_{h}} whL=t=0Tj=0tytLthtyt(k=j+1thk1hk)whhj

推导到这里,RNN梯度消失和梯度爆炸的原因就产生了。上述的第二个和第三个等式中出现了与时间 t t t相关的连乘的因式,根据第二节中RNN工作原理的介绍,以第二个等式为例,

∂ h k ∂ h k − 1 = f ′ ⋅ w i \frac{\partial h_{k}}{\partial h_{k-1}}=f^{'}\cdot w_{i} hk1hk=fwi

其中 f ′ f^{'} f为激活函数的导数,以 s i g m o i d sigmoid sigmoid函数为例, f ∈ ( 0 , 1 ) f\in(0,1) f(0,1)其导数为 f ′ = f ( 1 − f ) ∈ ( 0 , 1 4 ) f^{'}=f(1-f)\in(0,\frac{1}{4}) f=f(1f)(0,41),则 w i < 1 w_{i}<1 wi<1时, ∂ h k ∂ h k − 1 < 1 \frac{\partial h_{k}}{\partial h_{k-1}}<1 hk1hk<1,经过数次相乘后, ∂ L ∂ w i \frac{\partial L}{\partial w_{i}} wiL逐渐接近于0,即梯度消失; w i > 4 w_{i}>4 wi>4时, ∂ h k ∂ h k − 1 > 1 \frac{\partial h_{k}}{\partial h_{k-1}}>1 hk1hk>1,经过数次相乘后, ∂ L ∂ w i \frac{\partial L}{\partial w_{i}} wiL越来越大,即梯度爆炸。

  至此,我们就从理论上分析了RNN中存在梯度消失和梯度爆炸的原因。但为了能够使用RNN利用历史信息的特性,对RNN的结构进行适当的改造就能得到性能更加优越的LSTM。LSTM的结构大大缓解了传统RNN中存在的梯度消失和梯度爆炸的问题,从而使时间步能够大大增长。具体的分析请参考下一篇文章。

  • 3
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值