How LSTM Mitigates Vanishing and Exploding Gradients

[Figure: classic LSTM cell diagram]
This is the classic LSTM diagram. The LSTM controls its input and output through the gates $f_t$, $i_t$, and $o_t$:

$$f_t = \sigma\left(W_f \cdot [h_{t-1}, x_t] + b_f\right)$$

$$i_t = \sigma\left(W_i \cdot [h_{t-1}, x_t] + b_i\right)$$

$$o_t = \sigma\left(W_o \cdot [h_{t-1}, x_t] + b_o\right)$$
We simplify these to:

$$f_t = \sigma\left(W_f X_t + b_f\right)$$

$$i_t = \sigma\left(W_i X_t + b_i\right)$$

$$o_t = \sigma\left(W_o X_t + b_o\right)$$
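As a minimal sketch of the simplified gate equations above, each gate is just a sigmoid of an affine function of the input. The scalar weights and inputs below are hypothetical values chosen for illustration:

```python
import math

def sigmoid(z):
    """Logistic sigmoid, squashing any real value into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def gates(x_t, W_f, b_f, W_i, b_i, W_o, b_o):
    """Compute the three simplified LSTM gates for a scalar input x_t:
    f_t = sigma(W_f x_t + b_f), i_t = sigma(W_i x_t + b_i),
    o_t = sigma(W_o x_t + b_o)."""
    f_t = sigmoid(W_f * x_t + b_f)
    i_t = sigmoid(W_i * x_t + b_i)
    o_t = sigmoid(W_o * x_t + b_o)
    return f_t, i_t, o_t

# Hypothetical weights; each gate value lies strictly in (0, 1).
f_t, i_t, o_t = gates(1.0, 2.0, 0.5, 1.0, 0.0, -1.0, 0.0)
```

Because every gate is a sigmoid output, it acts as a soft switch between 0 (block) and 1 (pass), which is the mechanism the derivation below relies on.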
The current state is $S_t = f_t S_{t-1} + i_t X_t$, analogous to the vanilla-RNN update $S_t = W_s S_{t-1} + W_x X_t + b_1$. Expanding the LSTM state update gives:

$$S_t = \sigma\left(W_f X_t + b_f\right) S_{t-1} + \sigma\left(W_i X_t + b_i\right) X_t$$

Adding the activation function:

$$S_t = \tanh\left[\sigma\left(W_f X_t + b_f\right) S_{t-1} + \sigma\left(W_i X_t + b_i\right) X_t\right]$$

As discussed in the article on the causes of vanishing and exploding gradients in RNNs, backpropagation through a vanilla RNN involves the product:

$$\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} = \prod_{j=k+1}^{t} \tanh' \, W_s$$

The LSTM contains the analogous term, but here:

$$\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} = \prod_{j=k+1}^{t} \tanh' \, \sigma\left(W_f X_t + b_f\right)$$

Let $Z = \tanh'(x)\,\sigma(y)$; the surface of $Z$ is shown in the figure below:

[Figure: surface plot of $Z = \tanh'(x)\,\sigma(y)$]
As the plot shows, the value of this function is almost always either 0 or 1.
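We can spot-check this claim numerically. Using $\tanh'(x) = 1 - \tanh^2(x)$, the sketch below evaluates $Z = \tanh'(x)\,\sigma(y)$ at a few hypothetical points: near $x = 0$ with the gate saturated open, with the gate saturated closed, and far from the origin:

```python
import math

def sigmoid(y):
    return 1.0 / (1.0 + math.exp(-y))

def tanh_prime(x):
    """Derivative of tanh: 1 - tanh(x)^2, which peaks at 1 when x = 0."""
    return 1.0 - math.tanh(x) ** 2

def Z(x, y):
    return tanh_prime(x) * sigmoid(y)

z_open   = Z(0.0, 6.0)   # gate saturated open: Z close to 1
z_closed = Z(0.0, -6.0)  # gate saturated closed: Z close to 0
z_far    = Z(4.0, 6.0)   # tanh' vanishes for large |x|: Z close to 0
```

Away from the narrow transition band of the sigmoid and the peak of $\tanh'$, the product is indeed driven toward either 0 or 1.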
The partial derivative in a vanilla RNN is:

$$\frac{\partial L_3}{\partial W_s}=\frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_s}+\frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial W_s}+\frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial S_2} \frac{\partial S_2}{\partial S_1} \frac{\partial S_1}{\partial W_s}$$
In the LSTM this reduces to:

$$\frac{\partial L_3}{\partial W_s}=\frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_3}{\partial W_s}+\frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_2}{\partial W_s}+\frac{\partial L_3}{\partial O_3} \frac{\partial O_3}{\partial S_3} \frac{\partial S_1}{\partial W_s}$$
This is because

$$\prod_{j=k+1}^{t} \frac{\partial S_j}{\partial S_{j-1}} = \prod_{j=k+1}^{t} \tanh' \, \sigma\left(W_f X_t + b_f\right) \approx 0 \text{ or } 1,$$

so each intermediate factor either vanishes or passes the gradient through unchanged.
This is how the LSTM alleviates the vanishing-gradient problem of the vanilla RNN.
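The contrast between the two products can be sketched numerically. Below, the vanilla-RNN chain multiplies factors $\tanh' \cdot W_s$ (with hypothetical values $\tanh' \approx 0.5$, $W_s = 0.9$), while the LSTM chain multiplies $\tanh' \cdot \sigma(W_f x_t + b_f)$ with a forget gate biased to saturate open; all weights and inputs are illustrative, not learned values:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

T = 50              # number of unrolled time steps
xs = [0.5] * T      # hypothetical constant input sequence

# Vanilla RNN: each factor is tanh'(a_j) * W_s. With tanh' ~ 0.5 at a
# typical operating point and W_s = 0.9, the product shrinks geometrically.
rnn_product = 1.0
for _ in xs:
    rnn_product *= 0.5 * 0.9

# LSTM: each factor is tanh' * sigma(W_f x_t + b_f). A forget-gate bias
# of 5.0 keeps the gate saturated open, so each factor stays near 1
# (tanh' is taken as ~1 near the origin).
W_f, b_f = 1.0, 5.0
lstm_product = 1.0
for x in xs:
    lstm_product *= 1.0 * sigmoid(W_f * x + b_f)
```

After 50 steps the RNN product is numerically indistinguishable from zero, while the LSTM product remains of order 1, which is exactly the behavior the derivation above predicts when the forget gate stays open.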
