算法笔记(四)梯度消失 梯度爆炸

前言

梯度爆炸和梯度消失问题都是因为网络太深,网络权值更新不稳定造成的,本质上是因为梯度反向传播中的连乘效应。

w1
w2
省略号
w4
X
a1
a2
an-1
an

前向传播:
z 1 = w 1 X + b 1 , a 1 = σ ( z 1 ) z 2 = w 2 a 1 + b 2 , a 2 = σ ( z 2 ) . . . z n = w n a n − 1 + b n , a n = σ ( z n ) \begin{aligned} z_1&=w_1X+b_1,a_1=\sigma (z_1)\\ z_2&=w_2a_1+b_2,a_2=\sigma(z_2)\\ ...\\ z_n&=w_na_{n-1+b_n},a_n=\sigma(z_n)\\ \end{aligned} z1z2...zn=w1X+b1,a1=σ(z1)=w2a1+b2,a2=σ(z2)=wnan1+bn,an=σ(zn)
则反向传播:
α l o s s α w 1 = α l o s s α a n α a n α z n α z n α a n − 1 α a n − 1 α z n − 1 α z n − 1 α a n − 2 α a n − 2 α z n − 2 . . . α a 1 α z 1 α z 1 α w 1 = α l o s s s α a n ⋅ σ ′ ( z n ) w n ⋅ σ ′ ( z n − 1 ) w n − 1 ⋅ . . . ⋅ σ ′ ( z 1 ) X \begin{aligned} \frac{\alpha loss}{\alpha w_1} &=\frac{\alpha loss}{\alpha a_n}\frac{\alpha a_n}{\alpha z_n}\frac{\alpha z_n}{\alpha a_{n-1}}\frac{\alpha a_{n-1}}{\alpha z_{n-1}}\frac{\alpha z_{n-1}}{\alpha a_{n-2}}\frac{\alpha a_{n-2}}{\alpha z_{n-2}}...\frac{\alpha a_1}{\alpha z_1}\frac{\alpha z_1}{\alpha w_1}\\ &=\frac{\alpha losss}{\alpha a_n}·\sigma'(z_n)w_n·\sigma'(z_{n-1})w_{n-1}·...·\sigma'(z_1)X \end{aligned} αw1αloss=αanαlossαznαanαan1αznαzn1αan1αan2αzn1αzn2αan2...αz1αa1αw1αz1=αanαlosssσ(zn)wnσ(zn1)wn1...σ(z1)X

  • 梯度消失:与激活函数的导数 σ ′ ( x ) \sigma^{'}(x) σ(x)有关。
    假如 σ \sigma σ为sigmoid激活函数,而sigmoid的导数范围是[0,0.25],"链式法则"的累乘会导致梯度趋于0.

  • 梯度爆炸:与权重有关,即 ∣ σ ′ ( z ) w ∣ > 1 |\sigma'(z) w|>1 σ(z)w>1
    链式法则还与 ∣ σ ′ ( z ) w ∣ |\sigma'(z) w| σ(z)w有关,如果该值>1,"链式法则"累乘后会导致梯度趋于非常大的值.

梯度消失

与梯度太小有关。表现为只在后层学习,浅层不学习,浅层梯度基本无,权重改变量小,收敛慢,训练速度慢。

原因:

  1. 采用了不适合的激活函数,导致链式法则累乘时被0影响。
  2. 模型在训练的过程中,会不断调整数据分布,有可能接近激活函数饱和区,此时的导数很小,难以调整权重。

解决办法:

  1. 使用BN,将数据分布归一化。
  2. 预训练,微调。
  3. 使用relu等激活函数。
  4. 使用残差结构。
  5. LSTM。
  6. 正则化。

梯度爆炸

与链式法则中的权重有关。可能导致权重NAN。
原因:

  1. 若初始化权重太大,累乘后会爆炸。
  2. 梯度>1。

解决办法:

  1. 注意权重初始化。
  2. 梯度剪裁。
  3. BN。
  4. 预训练,微调。

RNN为何会梯度消失/爆炸?

首先看RNN计算流程,简设3个timestep:
在这里插入图片描述

前向传播:
S 1 = W x X 1 + W s S 0 + b 1 S_1=W_xX_1+W_sS_0+b1 S1=WxX1+WsS0+b1 O 1 = W o S 1 + b 2 O_1=W_oS_1+b2 O1=WoS1+b2
S 2 = W x X 2 + W s S 1 + b 1 S_2=W_xX_2+W_sS_1+b1 S2=WxX2+WsS1+b1 O 2 = W o S 2 + b 2 O_2=W_oS_2+b2 O2=WoS2+b2
S 3 = W x X 3 + W s S 2 + b 1 S_3=W_xX_3+W_sS_2+b1 S3=WxX3+WsS2+b1 O 3 = W o S 3 + b 2 O_3=W_oS_3+b2 O3=WoS3+b2

此刻的损失函数: l o s s 3 = 1 2 ( Y 3 − O 3 ) 2 loss_3=\frac{1}{2}(Y_3-O_3)^2 loss3=21(Y3O3)2

反向传播:

需要对 W o W_o Wo W s W_s Ws W x W_x Wx求导,其中对 W s W_s Ws W x W_x Wx求导是同理的。

(1) δ l o s s 3 δ W o = δ l o s s 3 δ O 3 δ O 3 δ W o \frac{\delta loss_3}{\delta W_o}=\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta W_o} δWoδloss3=δO3δloss3δWoδO3

可以看出网络加深对于 W o W_o Wo无影响。

(2) δ l o s s 3 δ W s = δ l o s s 3 δ O 3 δ O 3 δ S 3 δ S 3 δ W s + δ l o s s 3 δ O 3 δ O 3 δ S 3 δ S 3 δ S 2 δ S 2 δ W s + δ l o s s 3 δ O 3 δ O 3 δ S 3 δ S 3 δ S 2 δ S 2 δ S 1 δ S 1 δ W s \frac{\delta loss_3}{\delta W_s}=\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta W_s}+\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta S_2}\frac{\delta S_2}{\delta W_s}+\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta S_2}\frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_s} δWsδloss3=δO3δloss3δS3δO3δWsδS3+δO3δloss3δS3δO3δS2δS3δWsδS2+δO3δloss3δS3δO3δS2δS3δS1δS2δWsδS1

可以简写为:

δ l o s s t δ W s = ∑ k = 0 t δ l o s s t δ O t δ O t δ S t ∏ j = k + 1 t ( δ S j δ S j − 1 ) δ S k δ W x \frac{\delta loss_t}{\delta W_s}=\sum_{k=0}^t\frac{\delta loss_t}{\delta O_t}\frac{\delta O_t}{\delta S_t}\prod_{j=k+1}^t(\frac{\delta S_j}{\delta S_{j-1}})\frac{\delta S_k}{\delta W_x} δWsδlosst=k=0tδOtδlosstδStδOtj=k+1t(δSj1δSj)δWxδSk

其中连乘的 ∏ j = k + 1 t ( δ S j δ S j − 1 ) \prod_{j=k+1}^t(\frac{\delta S_j}{\delta S_{j-1}}) j=k+1t(δSj1δSj)是导致梯度爆炸和消失的问题所在

RNN梯度与其他网络梯度的区别

  1. MLP/CNN 中不同的层有不同的参数,各是各的梯度;而 RNN 中同样的权重在各个时间步共享,最终的梯度= 各个时间步的梯度的和。
  2. RNN 中总的梯度是不会消失的。即便梯度越传越弱,那也只是远距离的梯度消失,由于近距离的梯度不会消失,所有梯度之和便不会消失。RNN 所谓梯度消失的真正含义是,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系

LSTM如何缓解梯度消失/爆炸?

LSTM介绍

在这里插入图片描述

  1. 遗忘门
    • 可求得: f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma (W_f·[h_{t-1},x_t]+b_f) ft=σ(Wf[ht1,xt]+bf).在这里插入图片描述
  2. 输入门
    可求得:
    • i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t=\sigma (W_i·[h_{t-1},x_t]+b_i) it=σ(Wi[ht1,xt]+bi).
    • C ^ t = t a n h ( W C ⋅ [ h t − 1 , x t ] + b C ) \hat C_t=tanh (W_C·[h_{t-1},x_t]+b_C) C^t=tanh(WC[ht1,xt]+bC).在这里插入图片描述
    • C t = f t ⋅ C t − 1 + i t ⋅ C ^ t C_t=f_t·C_{t-1}+i_t·\hat C_t Ct=ftCt1+itC^t
      在这里插入图片描述
  3. 输出门
    可求得:
    • O t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) O_t=\sigma (W_o·[h_{t-1},x_t]+b_o) Ot=σ(Wo[ht1,xt]+bo).在这里插入图片描述
    • h t = O t ⋅ t a n h ( C t ) h_t=O_t·tanh(C_t) ht=Ottanh(Ct).在这里插入图片描述

LSTM可以解决梯度消失,缓解梯度爆炸

整理可得公式:

  • f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma (W_f·[h_{t-1},x_t]+b_f) ft=σ(Wf[ht1,xt]+bf).
  • i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t=\sigma (W_i·[h_{t-1},x_t]+b_i) it=σ(Wi[ht1,xt]+bi).
  • C ^ t = t a n h ( W c ⋅ [ h t − 1 , x t ] + b c ) \hat C_t=tanh (W_c·[h_{t-1},x_t]+b_c) C^t=tanh(Wc[ht1,xt]+bc).
  • C t = f t ⋅ C t − 1 + i t ⋅ C ^ t C_t=f_t·C_{t-1}+i_t·\hat C_t Ct=ftCt1+itC^t
  • O t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) O_t=\sigma (W_o·[h_{t-1},x_t]+b_o) Ot=σ(Wo[ht1,xt]+bo).
  • h t = O t ⋅ t a n h ( C t ) h_t=O_t·tanh(C_t) ht=Ottanh(Ct).
  1. LSTM 中梯度的传播有很多条路径, C t − 1 → C t = f t ⋅ c t − 1 + i t ⋅ c ^ t C_{t-1} \rightarrow C_t=f_t·c_{t-1}+i_t·\hat c_t Ct1Ct=ftct1+itc^t这条路径上只有逐元素相乘和相加的操作,梯度流最稳定;但是其他路径(例如 C t − 1 → h t − 1 → i t → c t C_{t-1} \rightarrow h_{t-1} \rightarrow i_t \rightarrow c_t Ct1ht1itct)上梯度流与普通 RNN 类似,照样会发生相同的权重矩阵反复连乘。根据上式可以看出 C t C_t Ct公式与 h t h_t ht, i t i_t it, C ^ t \hat C_t C^t, C t − 1 C_{t-1} Ct1有关,则可以得出:
    δ C ( k ) δ C ( k − 1 ) = δ C ( k ) δ f ( k ) δ f ( k ) δ h ( k − 1 ) δ h ( k − 1 ) δ C ( k − 1 ) [ h t 公 式 ] + δ C ( k ) δ i ( k ) δ i ( k ) δ h ( k − 1 ) δ h ( k − 1 ) δ C ( k − 1 ) [ i t 公 式 ] + δ C ( k ) δ C ^ ( k ) δ C ^ ( k ) δ h ( k − 1 ) δ h ( k − 1 ) δ C ( k − 1 ) [ C ^ t 公 式 ] + δ C ( k ) δ C ( k − 1 ) [ C t 公 式 ] = C t − 1 ( σ ′ ⋅ W f ) ( o t ⋅ t a n h ′ ) + C ^ t ( σ ′ ⋅ W i ) ( o t ⋅ t a n h ′ ) + i t ( t a n h ′ ⋅ W c ) ( o t ⋅ t a n h ′ ) + f t \begin{aligned} \frac{\delta C^{(k)}}{\delta C^{(k-1)}} &=\frac{\delta C^{(k)}}{\delta f^{(k)}}\frac{\delta f^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[h_t公式]\\ &+\frac{\delta C^{(k)}}{\delta i^{(k)}}\frac{\delta i^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[i_t公式]\\ &+\frac{\delta C^{(k)}}{\delta \hat C^{(k)}}\frac{\delta \hat C^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[\hat C_t公式]\\ &+\frac{\delta C^{(k)}}{\delta C^{(k-1)}} [C_t公式]\\ &=C^{t-1}(\sigma'·W_f)(o^t·tanh')\\ &+\hat C^{t}(\sigma'·W_i)(o^t·tanh')\\ &+i^{t}(tanh'·W_c)(o^t·tanh')\\ &+f_t \end{aligned} δC(k1)δC(k)=δf(k)δC(k)δh(k1)δf(k)δC(k1)δh(k1)[ht]+δi(k)δC(k)δh(k1)δi(k)δC(k1)δh(k1)[it]+δC^(k)δC(k)δh(k1)δC^(k)δC(k1)δh(k1)[C^t]+δC(k1)δC(k)[Ct]=Ct1(σWf)(ottanh)+C^t(σWi)(ottanh)+it(tanhWc)(ottanh)+ft
    因此RNN的问题 ∏ j = k t \prod_{j=k}^t j=kt在LSTM中等价于 ( f k ⋅ f k + 1 ⋅ f 2 ⋅ . . . ⋅ f t ) + o t h e r (f^{k}·f^{k+1}·f^{2}·...·f^{t})+other (fkfk+1f2...ft)+other

  2. 正常梯度 + 消失梯度 = 正常梯度,总的远距离梯度就不会消失,因此 LSTM 可以解决梯度消失。

    • 可自主选择[0,1]之间,当遗忘门接近 1时(例如模型初始化时会把 forget bias 设置成较大的正数,让遗忘门饱和),这时候远距离梯度不消失;
    • 当遗忘门接近 0时,但这时模型是故意阻断梯度流的(例如情感分析任务中有一条样本 “A,但是 B”,模型读到“但是”后选择把遗忘门设置成 0,遗忘掉内容 A,这是合理的)。
  3. 正常梯度 + 爆炸梯度 = 爆炸梯度,因此 LSTM 仍然有可能发生梯度爆炸。不过,由于 LSTM 和普通 RNN 相比多经过了很多次激活函数(导数都小于 1),因此 LSTM 发生梯度爆炸的频率要低得多。

参考

https://zhuanlan.zhihu.com/p/25631496
https://www.cnblogs.com/bonelee/p/10475453.html
https://www.zhihu.com/question/34878706/answer/665429718

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

nooobme

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值