rnn梯度弥散 LSTM无梯度弥散

之前看过,现在突然想不起,真的是好记性不如烂笔头,希望大家在看的时候能够拿笔和纸跟着推导一遍,加深理解。

转自:https://zhuanlan.zhihu.com/p/28687529 https://zhuanlan.zhihu.com/p/28749444

1.RNN梯度弥散和爆炸的原因

经典的RNN结构如下图所示:
在这里插入图片描述
假设我们的时间序列只有三段, S 0 S_0 S0为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:
在这里插入图片描述
假设在t=3时刻,损失函数为 L 3 = 1 2 ( Y 3 − O 3 ) 2 L_3=\frac{1}{2}(Y_3-O_3)^2 L3=21(Y3O3)2

则对于一次训练任务的损失函数为 L = ∑ t = 1 T L t L=\sum_{t=1}^{T}L_t L=t=1TLt,即每一时刻损失值的累加。

使用随机梯度下降法训练RNN其实就是对 W x W_x Wx W s W_s Ws W 0 W_0 W0 以及 b 1 b_1 b1 b 2 b_2 b2 求偏导,并不断调整它们以使L尽可能达到最小的过程。

现在假设我们我们的时间序列只有三段,t1,t2,t3。

我们只对t3时刻的 [公式] 求偏导(其他时刻类似):
在这里插入图片描述
可以看出对于 W 0 W_0 W0 求偏导并没有长期依赖,但是对于 W x W_x Wx W s W_s Ws求偏导,会随着时间序列产生长期依赖。因为 S t S_t St随着时间序列向前传播,而 S t S_t St又是 W x W_x Wx W s W_s Ws的函数。

根据上述求偏导的过程,我们可以得出任意时刻对 W x W_x Wx W s W_s Ws求偏导的公式:
在这里插入图片描述

任意时刻对 W s W_s Ws 求偏导的公式同上。

如果加上激活函数, S j = t a n h ( W x X j + W s S j − 1 + b 1 ) S_j=tanh(W_xX_j+W_sS_{j-1}+b_1) Sj=tanh(WxXj+WsSj1+b1)
∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}tanh^{'}W_s j=k+1tSj1Sj=j=k+1ttanhWs
激活函数tanh和它的导数图像如下。

在这里插入图片描述

由上图可以看出 t a n h ′ ≤ 1 tanh^{'} \leq1 tanh1,对于训练过程大部分情况下tanh的导数是小于1的,因为很少情况下会出现 W x X j + W s S j − 1 + b 1 = 0 W_xX_j+W_sS_{j-1}+b_1=0 WxXj+WsSj1+b1=0,如果 W s W_s Ws 也是一个大于0小于1的值,则当t很大时 ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}tanh^{'}W_s j=k+1ttanhWs,就会趋近于0,和 0.01^{50} 趋近与0是一个道理。同理当 W s W_s Ws很大时 ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}tanh^{'}W_s j=k+1ttanhWs 就会趋近于无穷,这就是RNN中梯度消失和爆炸的原因。

至于怎么避免这种现象,让我在看看 ∂ L t ∂ W x = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ S t ( ∏ j = k + 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ W x \frac{\partial{L_t}}{\partial{W_x}}=\sum_{k=0}^{t}\frac{\partial{L_t}}{\partial{O_t}}\frac{\partial{O_t}}{\partial{S_t}}(\prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}})\frac{\partial{S_k}}{\partial{W_x}} WxLt=k=0tOtLtStOt(j=k+1tSj1Sj)WxSk梯度消失和爆炸的根本原因就是 ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} j=k+1tSj1Sj这一坨,要消除这种情况就需要把这一坨在求偏导的过程中去掉,至于怎么去掉,一种办法就是使 ∏ j = k + 1 t ∂ S j ∂ S j − 1 ≈ 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}\approx1 j=k+1tSj1Sj1另一种办法就是使 ∏ j = k + 1 t ∂ S j ∂ S j − 1 ≈ 0 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}\approx0 j=k+1tSj1Sj0 。其实这就是LSTM做的事情,至于细节问题下节将进行介绍。

2.LSTM如何解决梯度消失问题

先上一张LSTM的经典图:
在这里插入图片描述

而LSTM可以抽象成这样:

在这里插入图片描述

三个×分别代表的就是forget gate,input gate,output gate,而我认为LSTM最关键的就是forget gate这个部件。这三个gate是如何控制流入流出的呢,其实就是通过下面 f t , i t , o t f_t,i_t,o_t ft,it,ot 三个函数来控制,因为 σ ( x ) \sigma(x) σ(x)(代表sigmoid函数) 的值是介于0到1之间的,刚好用趋近于0时表示流入不能通过gate,趋近于1时表示流入可以通过gate。
在这里插入图片描述
当前的状态 S t = f t S t − 1 + i t X t S_t=f_tS_{t-1}+i_tX_t St=ftSt1+itXt类似与传统RNN S t = W s S t − 1 + W x X t + b 1 S_t=W_sS_{t-1}+W_xX_t+b_1 St=WsSt1+WxXt+b1。将LSTM的状态表达式展开后得:
在这里插入图片描述
如果加上激活函数, S t = t a n h [ σ ( W f X t + b f ) S t − 1 + σ ( W i X t + b i ) X t ] S_t=tanh[\sigma(W_fX_t+b_f)S_{t-1}+\sigma(W_iX_t+b_i)X_t] St=tanh[σ(WfXt+bf)St1+σ(WiXt+bi)Xt]

这篇文章中传统RNN求偏导的过程包含
在这里插入图片描述

对于LSTM同样也包含这样的一项,但是在LSTM中

在这里插入图片描述
假设 Z = t a n h ′ ( x ) σ ( y ) Z=tanh'(x)\sigma(y) Z=tanh(x)σ(y) ,则 Z Z Z的函数图像如下图所示:
在这里插入图片描述
可以看到该函数值基本上不是0就是1。

这篇文章中传统RNN的求偏导过程:
在这里插入图片描述

如果在LSTM中上式可能就会变成:
在这里插入图片描述

因为 ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t t a n h ′ σ ( W f X t + b f ) ≈ 0 ∣ 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}tanh'\sigma(W_fX_t+b_f)\approx0|1 j=k+1tSj1Sj=j=k+1ttanhσ(WfXt+bf)01 ,这样就解决了传统RNN中梯度消失的问题。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值