RNN的梯度消失与可以看懂的解决方法

本文参考了李宏毅机器学习视频RNN梯度消失和爆炸的原因LSTM如何解决梯度消失问题(很清晰但要梯子)神经网络中存在的问题LSTM原理详解(这篇也不错但是有点点拼拼凑凑)


【NLP基础理论】RNN 中简单介绍了RNN的基础理论,如果对RNN还没有认知的可以先看这篇。



LSTM现在都已经成为一个标准RNN,大家说RNN多半指的是LSTM,而最开始的RNN多称为 Simple RNN。所以本文主要是对于SimpleRNN为什么会存在 梯度消失/爆炸问题进行说明。

RNN 随着epoch增加,通常情况下是 loss 越来越小(下图蓝色),但是有的时候 loss 抖动的特别厉害,然后爆掉(下图绿色),这就是发生了梯度消失/爆炸。
在这里插入图片描述

直观感受梯度消失和爆炸(特例)

下方是一个极其简单的RNN,1000 个输入(时间步),除了第一个输入为 1 以外,其余均为 0。中间RNN函数为线性函数 y = x,没有bias,输入和输出的参数都为 1,输入给下一个 RNN 的参数为 w。那么根据公式
s i = w s i − 1 + 1 ⋅ x i y i = σ ( 1 ⋅ s i ) s_i=ws_{i-1}+1\cdot x_i \\ y_i=\sigma(1\cdot s_i) si=wsi1+1xiyi=σ(1si)
可以得到最后 y 1000 y_{1000} y1000 的输出为 w 999 w^{999} w999
在这里插入图片描述

当我们尝试去改变参数 w w w 的时候,我们会发现,只要参数有一点变化,对于长句子(长输入)来说,最后一向的输出会有很大的变化!
比如 w w w 从 1 变成 1.01,那么 最后的输出会从 1 变成约 20000,当 w w w 从 1 变成小于 1 的数字时,最后输出直接约等于 0。
梯度,我们也可以理解为 直观地看到参数 w w w 的变化会引起最后输出多大的变化。
在这里插入图片描述
因此上图中绿色部分 ∂ L ∂ w \frac{\partial L}{\partial w} wL 值较大,而黄色部分值较小。但我们可以通过对于梯度较大的地方用较小的learning rate来对冲,反之亦然。可是这很难解决,因为RNN这个性质会导致梯度时大时小,error surface就会非常崎岖。

数学感受梯度消失和梯度爆炸

简单回忆 反向传播(BP) 的流程:

这是一个简单的神经网络,其中 θ 21 ( 2 ) \theta_{21}^{(2)} θ21(2) 代表第 2 层参数链接了前面第 2 个神经元和后面第 1 个神经元的参数; z   1 1 , a   1 1 z^1_{\ 1},a^1_{\ 1} z 11,a 11分别表示第一层第一个激活函数的输入值和输出值。此处所有激活函数为sigmoid函数。
在这里插入图片描述

  • 误差函数 E = 1 2 ( y − y ^ ) 2 E=\frac{1}{2}(y-\hat{y})^2 E=21(yy^)2
  • 参数迭代公式
    θ j k ( l ) ← θ j k ( l ) + Δ θ j k ( l ) w h e r e , θ j k ( l ) = − η ∂ E ∂ θ j k ( l ) = η δ k ( l ) a j ( l − 1 ) \theta_{jk}^{(l)} \leftarrow \theta_{jk}^{(l)} + \Delta \theta_{jk}^{(l)}\\ where,\theta_{jk}^{(l)}=-\eta\frac{\partial E}{\partial \theta_{jk}^{(l)}} = \eta \delta^{(l)}_ka^{(l-1)}_j θjk(l)θjk(l)+Δθjk(l)where,θjk(l)=ηθjk(l)E=ηδk(l)aj(l1)
  • θ 11 ( 2 ) \theta^{(2)}_{11} θ11(2) 为例,最后一层参数更新公式为:
    E = 1 2 ( y − a 1 ( 2 ) ) 2 δ k ( l ) = ( 1 − σ ( z k ( l ) ) ) σ ( z k ( l ) ) ( y − a k ( l ) ) E=\frac{1}{2}(y-a_1^{(2)})^2\\ \delta_k^{(l)}=(1-\sigma(z^{(l)}_k))\sigma(z^{(l)}_k)(y-a_k^{(l)}) E=21(ya1(2))2δk(l)=(1σ(zk(l)))σ(zk(l))(yak(l))
    推导: ∂ E ∂ θ 11 ( 2 ) = ∂ E ∂ a   1 2 ⋅ ∂ a   1 2 ∂ z   1 2 ⋅ ∂ z   1 2 ∂ θ 11 ( 2 ) \frac{\partial E}{\partial \theta_{11}^{(2)}}=\frac{\partial E}{\partial a^2_{\ 1}}\cdot \frac{\partial a^2_{\ 1}}{\partial z^2_{\ 1}}\cdot \frac{\partial z^2_{\ 1}}{\partial \theta_{11}^{(2)}} θ11(2)E=a 12Ez 12a 12θ11(2)z 12
    其中 E = 1 2 ( y − a   1 2 ) 2 → ∂ E ∂ a   1 2 = − y + a   1 2 a   1 2 = 1 1 + e − z   1 2 → ∂ a   1 2 ∂ z   1 2 = e − z   1 2 ( 1 + e − z   1 2 ) 2 = ( 1 − σ ( z   1 2 ) ) σ ( z   1 2 ) z   1 2 = a   0 1 ⋅ θ 01 ( 2 ) + a   1 1 ⋅ θ 11 ( 2 ) + a   2 1 ⋅ θ 21 ( 2 ) → ∂ z   1 2 ∂ θ 11 ( 2 ) = a   1 1 E=\frac{1}{2}(y-a^2_{\ 1})^2 \rightarrow \frac{\partial E}{\partial a^2_{\ 1}}=-y+a^2_{\ 1}\\ a^2_{\ 1}=\frac{1}{1+e^{-z^2_{\ 1}}} \rightarrow \frac{\partial a^2_{\ 1}}{\partial z^2_{\ 1}}=\frac{e^{-z^2_{\ 1}}}{(1+e^{-z^2_{\ 1}})^2}=(1-\sigma(z^2_{\ 1}))\sigma(z^2_{\ 1})\\ z^2_{\ 1}=a^1_{\ 0} \cdot \theta_{01}^{(2)}+a^1_{\ 1} \cdot \theta_{11}^{(2)}+a^1_{\ 2}\cdot \theta_{21}^{(2)} \rightarrow \frac{\partial z^2_{\ 1}}{\partial \theta_{11}^{(2)}}=a^1_{\ 1} E=21(ya 12)2a 12E=y+a 12a 12=1+ez 121z 12a 12=(1+ez 12)2ez 12=(1σ(z 12))σ(z 12)z 12=a 01θ01(2)+a 11θ11(2)+a 21θ21(2)θ11(2)z 12=a 11
  • θ 11 ( 1 ) \theta ^{(1)}_{11} θ11(1) 其它层的参数更新公式为:
    δ k ( l ) = ( 1 − σ ( z k ( l ) ) ) σ ( z k ( l ) ) θ k 1 ( l + 1 ) δ 1 ( l + 1 ) \delta_k^{(l)}=(1-\sigma(z^{(l)}_k))\sigma(z^{(l)}_k)\theta_{k1}^{(l+1)}\delta_1^{(l+1)} δk(l)=(1σ(zk(l)))σ(zk(l))θk1(l+1)δ1(l+1)
    推导: ∂ E ∂ θ 11 ( 2 ) = ∂ E ∂ a   1 2 ⋅ ∂ a   1 2 ∂ z   1 2 ⋅ ∂ z   1 2 ∂ a   1 2 ⋅ ∂ a   1 1 ∂ z   1 1 ⋅ ∂ z   1 1 ∂ θ 11 ( 1 ) \frac{\partial E}{\partial \theta_{11}^{(2)}}=\frac{\partial E}{\partial a^2_{\ 1}}\cdot \frac{\partial a^2_{\ 1}}{\partial z^2_{\ 1}}\cdot \frac{\partial z^2_{\ 1}}{\partial a^2_{\ 1}}\cdot \frac{\partial a^1_{\ 1}}{\partial z^1_{\ 1}}\cdot \frac{\partial z^1_{\ 1}}{\partial \theta ^{(1)}_{11}} θ11(2)E=a 12Ez 12a 12a 12z 12z 11a 11θ11(1)z 11
    这个推导没有出现一个神经元会有两个 E E E 共同带来的误差。如果有多个的话,这要把每个 E E E 的值相加即可。

题外话:基本上所有的均方误差损失函数(MSE) 都是以下形式表现: L ( Y ∣ f ( x ) ) = 1 N ∑ i = 1 N ( Y i − f ( x i ) ) 2 L(Y|f(x))=\frac{1}{N}\sum_{i=1}^{N}(Y_i-f(x_i))^2 L(Yf(x))=N1i=1N(Yif(xi))2
但后来突然发现还有将 N N N变为 2 N 2N 2N作为分母使用的(参考微软github
题外话结束

简单回忆 SimpleRNN 模型:

S i = t a n h ( W s S i − 1 + W x x i + b 1 ) O i = σ ( W o S i + b 2 ) S_i = tanh(W_sS_{i-1}+W_xx_i+b_1)\\ O_i = \sigma (W_oS_i+b_2) Si=tanh(WsSi1+Wxxi+b1)Oi=σ(WoSi+b2)
在这里插入图片描述
好的,回忆完了,RNN会有多个输出,这里如果要使用BP的话一定是要考虑到时间的,因此RNN的BP叫做时间反向传播-BPTT(Back Propagation Through Time)

开始BPTT

参考链接
再从最简单的例子开始,下图是只有三个输入的RNN,没有任何激活函数:
s 1 = w x x 1 + w s s 0 + b 1 , y 1 = w y s 1 + b 2 s 2 = w x x 2 + w s s 1 + b 1 , y 2 = w y s 2 + b 2 s 3 = w x x 2 + w s s 2 + b 1 , y 2 = w y s 3 + b 2 s_1 = w_xx_1+w_ss_0+b_1,y_1=w_ys_1+b_2\\ s_2 = w_xx_2+w_ss_1+b_1,y_2=w_ys_2+b_2\\ s_3 = w_xx_2+w_ss_2+b_1,y_2=w_ys_3+b_2 s1=wxx1+wss0+b1,y1=wys1+b2s2=wxx2+wss1+b1,y2=wys2+b2s3=wxx2+wss2+b1,y2=wys3+b2
在这里插入图片描述
那在 t = 3 t=3 t=3 的时刻,误差函数为 E 3 = 1 2 ( Y 3 − y 3 ) 2 E_3=\frac{1}{2}(Y_3-y_3)^2 E3=21(Y3y3)2
一次训练下所有的误差值为单个误差之和:
E = ∑ t = 0 T E t E = \sum_{t=0}^{T}E_t E=t=0TEt
而我们的目标是去更新所有的参数 w x , w s , w y w_x,w_s,w_y wx,ws,wy,所以需要计算误差项的梯度。
RNN中误差项的梯度并更新参数:
∂ E ∂ W = ∑ t = 1 T ∂ E t ∂ W W ← W − η ∂ E ∂ W \frac{\partial E}{\partial W} = \sum_{t=1}^T\frac{\partial E_t}{\partial W}\\ W \leftarrow W - \eta \frac{\partial E}{\partial W} WE=t=1TWEtWWηWE
此处,我们对 t = 1 t=1 t=1 时刻入手开始更新:
∂ E 1 ∂ w y = ∂ E 1 ∂ y 1 ∂ y 1 ∂ w y ∂ E 1 ∂ w x = ∂ E 1 ∂ y 1 ∂ y 1 ∂ s 1 ∂ s 1 ∂ w x ∂ E 1 ∂ w s = ∂ E 1 ∂ y 1 ∂ y 1 ∂ s 1 ∂ s 1 ∂ w s \frac{\partial E_1}{\partial w_y}=\frac{\partial E_1}{\partial y_1}\frac{\partial y_1}{\partial w_y}\\ \frac{\partial E_1}{\partial w_x}=\frac{\partial E_1}{\partial y_1}\frac{\partial y_1}{\partial s_1}\frac{\partial s_1}{\partial w_x}\\ \frac{\partial E_1}{\partial w_s}=\frac{\partial E_1}{\partial y_1}\frac{\partial y_1}{\partial s_1}\frac{\partial s_1}{\partial w_s}\\ wyE1=y1E1wyy1wxE1=y1E1s1y1wxs1wsE1=y1E1s1y1wss1
如果从 t=3 时刻开始,那么就需要每一次向后传递时,分一部分给 w x w_x wx再分一部分错误给后面。
在这里插入图片描述

∂ E ∂ w x = ∂ E 3 ∂ y 3 ∂ y 3 ∂ s 3 ∂ s 3 ∂ w x + ∂ E 3 ∂ y 3 ∂ y 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ w x + ∂ E 3 ∂ y 3 ∂ y 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 ∂ s 1 ∂ w x ∂ E ∂ w s = ∂ E 3 ∂ y 3 ∂ y 3 ∂ s 1 ∂ s 3 ∂ w s + ∂ E 3 ∂ y 3 ∂ y 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ w s + ∂ E 3 ∂ y 3 ∂ y 3 ∂ s 3 ∂ s 3 ∂ s 2 ∂ s 2 ∂ s 1 ∂ s 1 ∂ w s ( 因 为 s 3 是 包 含 了 s 2 和 s 1 的 ) \frac{\partial E}{\partial w_x}=\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial w_x}+\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial w_x}+\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial w_x}\\ \frac{\partial E}{\partial w_s}=\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_1}\frac{\partial s_3}{\partial w_s}+\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial w_s}+\frac{\partial E_3}{\partial y_3}\frac{\partial y_3}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial w_s}\\(因为s_3是包含了s_2和s_1的) wxE=y3E3s3y3wxs3+y3E3s3y3s2s3wxs2+y3E3s3y3s2s3s1s2wxs1wsE=y3E3s1y3wss3+y3E3s3y3s2s3wss2+y3E3s3y3s2s3s1s2wss1(s3s2s1)
对上述偏导公式进行总结,得出所有时刻的梯度之和
∂ E ∂ w x = ∑ k = 0 t ∂ E t ∂ y t ∂ y t ∂ s t ( ∏ j = k + 1 t ∂ s j ∂ s j − 1 ) ∂ s k ∂ w x \frac{\partial E}{\partial w_x}=\sum_{k=0}^t\frac{\partial E_t}{\partial y_t}\frac{\partial y_t}{\partial s_t}(\prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}})\frac{\partial s_k}{\partial w_x} wxE=k=0tytEtstyt(j=k+1tsj1sj)wxsk
w s w_s ws 同上。
因为上述式子是假设没有任何激活函数,下式是任意时刻的梯度传递到时间步1时候的公式
∂ E k ∂ w x = ∂ E k ∂ y k ∂ y k ∂ s k ( ∏ t = 2 k ∂ s t ∂ s t − 1 ) ∂ s 1 ∂ w x \frac{\partial E_k}{\partial w_x}=\frac{\partial E_k}{\partial y_k}\frac{\partial y_k}{\partial s_k}(\prod_{t=2}^k\frac{\partial s_t}{\partial s_{t-1}})\frac{\partial s_1}{\partial w_x} wxEk=ykEkskyk(t=2kst1st)wxs1
因此,
没有任何激活函数的情况下 ∏ j = k + 1 t ∂ s j ∂ s j − 1 \prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}} j=k+1tsj1sj t − k − 1 t-k-1 tk1 w s w_s ws相乘。那么 w s w_s ws的大小就会影响梯度爆炸还是消失。
有激活函数 : s j = t a n h ( w s s j − 1 + w x x j ) s_j=tanh(w_ss_{j-1}+w_xx_j) sj=tanh(wssj1+wxxj)
求偏导则先求 t a n h ( x ) tanh(x) tanh(x)的导,再求 x x x 的导: ∂ s j ∂ s j − 1 = t a n h ′ ( w s s j − 1 + w x x j ) ⋅ ∂ ∂ s j − 1 [ w s s j − 1 + w x x j ] = t a n h ′ ( w s s j − 1 + w x x j ) ⋅ w s \frac{\partial s_j}{\partial s_{j-1}}=tanh'(w_ss_{j-1}+w_xx_j)\cdot \frac{\partial}{\partial s_{j-1}}[w_ss_{j-1}+w_xx_j]\\ =tanh'(w_ss_{j-1}+w_xx_j)\cdot w_s sj1sj=tanh(wssj1+wxxj)sj1[wssj1+wxxj]=tanh(wssj1+wxxj)ws
在这里插入图片描述

t a n h ′ tanh' tanh 取值在[0,1],后面还是有 w s w_s ws !可恶。
(补充: 激活函数 t a n h ( x ) = e x − e − x e x + e − x tanh(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}} tanh(x)=ex+exexex)

重点!!!
所以到这里就可以发现,只要我时间步长够长,就会有越来越多的 w s w_s ws 相乘,如果说 w s w_s ws 正好也在 [0,1] ,配合上 t a n h ′ tanh' tanh,我们的梯度就消失了(所以我们可以认为是 t a n h ′ tanh' tanh w s w_s ws两者一起导致梯度消失)。
但当然,因为变量终究还是 w s w_s ws (可能会大于1,但 t a n h ′ tanh' tanh肯定小于等于1),如果 w s w_s ws 大的能抵消掉 那么多个 t a n h ′ tanh' tanh相乘 的时候,就会造成梯度爆炸

如果梯度消失,那么
∂ E ∂ W = ∑ t = 1 T ∂ E t ∂ W → 0 \frac{\partial E}{\partial W} = \sum_{t=1}^T\frac{\partial E_t}{\partial W} \rightarrow 0 WE=t=1TWEt0
这也就导致我们的参数在合理的时间内就没怎么更新了。
W ← W − η ∂ E ∂ W ≈ W W \leftarrow W - \eta \frac{\partial E}{\partial W} \approx W WWηWEW

干掉它:)

(参考链接)

现在知道了梯度消失和爆炸的问题就在于 ∂ E t ∂ w x = ∑ k = 0 t ∂ E t ∂ y t ∂ y t ∂ s t ( ∏ j = k + 1 t ∂ s j ∂ s j − 1 ) ∂ s k ∂ w x \frac{\partial E_t}{\partial w_x}=\sum_{k=0}^t\frac{\partial E_t}{\partial y_t}\frac{\partial y_t}{\partial s_t}(\prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}})\frac{\partial s_k}{\partial w_x} wxEt=k=0tytEtstyt(j=k+1tsj1sj)wxsk
中的 ∏ j = k + 1 t ∂ s j ∂ s j − 1 \prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}} j=k+1tsj1sj,最直观的想法就是让它乘来乘去一直约为 1 或者 一直约为 0 ,这样就不会对整体的梯度有很大的影响。

LSTM 可以解决。(Clockwise RNN 和 SCRN 也可以,但这里不讲了)

简单回忆LSTM

(一个非常详细的LSTM介绍)
下面是三个时间步长的LSTM, t t t时的输入是 [ h t − 1 , x t ] [h_{t-1},x_{t}] [ht1,xt] 当前输入与上一个的输出相结合。三个橙色的 σ \sigma σ 函数是LSTM的三个gate,从左至右分别为遗忘门、输入门以及输出门。
在这里插入图片描述

遗忘门 根据新的输入,经过激活函数之后,得到一个 0 - 1 的值,这个值决定了过去的记忆 c t − 1 c_{t-1} ct1 有多少被保留。:

f t = σ ( W f ⋅ [ h t − 1 , x t ] ) f_t = \sigma (W_f\cdot [h_{t-1},x_t]) ft=σ(Wf[ht1,xt])在这里插入图片描述

输入门 控制了新输入有多少要被加入到记忆中,但这里的输出还需要配合上神经网络觉得新输入中有用的部分。
t a n h ( W c ⋅ [ h t − 1 , x t ] ) ∗ σ ( W i ⋅ [ h t − 1 , x t ] ) tanh(W_c\cdot [h_{t-1},x_t])*\sigma(W_i \cdot [h_{t-1},x_t]) tanh(Wc[ht1,xt])σ(Wi[ht1,xt])
这里会拆成以下两个矩阵进行对应相乘:
c t ~ = t a n h ( W c ⋅ [ h t − 1 , x t ] ) i t = σ ( W i ⋅ [ h t − 1 , x t ] ) \tilde{c_t} = tanh(W_c \cdot [h_{t-1},x_t])\\ i_t=\sigma(W_i \cdot[h_{t-1},x_t]) ct~=tanh(Wc[ht1,xt])it=σ(Wi[ht1,xt])在这里插入图片描述

输出门 控制了有多少信息会被编辑到记忆细胞,作为下一个时间步的输入。
h t = σ ( W o ⋅ [ h t − 1 , x t ] ) ∗ t a n h ( c t ) h_t = \sigma(W_o \cdot [h_{t-1},x_t])*tanh(c_t) ht=σ(Wo[ht1,xt])tanh(ct)
在这里插入图片描述
记忆细胞 的内容就根据上述结果进行更新:
c t = c t − 1 ∗ f t + c t ~ ∗ i t c_t=c_{t-1} * f_t + \tilde{c_t}*i_t ct=ct1ft+ct~it

题外话:看到这里你会发现为什么LSTM会用到两个 t a n h tanh tanh 函数?
找了半天就这个链接点赞第二多的比较有道理(点赞最多的那个我实在不懂tanh怎么就有他说的多阶导能很长时间不为0的表现)。
大概的意思就是 t a n h tanh tanh的值在 [-1,1] 之间,但 sigmoid 在 [0,1] 之间,所以使用 t a n h tanh tanh 可以生成负值。我了解到在神经网络中,像 t a n h tanh tanh 这样的0中心激活函数可以加快收敛速度(关于0中心这篇不错),那可能也是这个原因我们在这里使用 t a n h tanh tanh,应该是其他激活函数也可用(但这里不是百分百保证,最好对于激活函数的区别再了解一下)。
题外话结束

LSTM中的BPTT

总结一下 LSTM 里发生的公式们( ⋅ \cdot 表示矩阵相乘, ∗ * 表示矩阵元素对应相乘):
f t = σ ( W f ⋅ [ h t − 1 , x t ] ) c t ~ = t a n h ( W c ⋅ [ h t − 1 , x t ] ) i t = σ ( W i ⋅ [ h t − 1 , x t ] ) o t = σ ( W o ⋅ [ h t − 1 , x t ] ) h t = o t ∗ t a n h ( c t ) , 输 出 c t = c t − 1 ∗ f t + c t ~ ∗ i t , 记 忆 更 新 f_t = \sigma (W_f\cdot [h_{t-1},x_t])\\\tilde{c_t} = tanh(W_c \cdot [h_{t-1},x_t])\\ i_t=\sigma(W_i \cdot[h_{t-1},x_t])\\ o_t = \sigma(W_o \cdot [h_{t-1},x_t])\\ h_t = o_t*tanh(c_t) ,输出 \\c_t=c_{t-1} * f_t + \tilde{c_t}*i_t,记忆更新 ft=σ(Wf[ht1,xt])ct~=tanh(Wc[ht1,xt])it=σ(Wi[ht1,xt])ot=σ(Wo[ht1,xt])ht=ottanh(ct)ct=ct1ft+ct~it

现在我们放一个时间步长为3的 LSTM :
在这里插入图片描述
列出涉及到的式子(嵌套的就不打开了,太多太乱了):
c 1 = c 0 ∗ f 1 + t a n h ( W c ⋅ [ h 0 , x 1 ] ) ∗ i 1 h 1 = o 1 ∗ t a n h ( c 1 ) c 2 = c 1 ∗ f 2 + t a n h ( W c ⋅ [ h 1 , x 2 ] ) ∗ i 2 h 2 = o 2 ∗ t a n h ( c 2 ) c 3 = c 2 ∗ f 3 + t a n h ( W c ⋅ [ h 2 , x 3 ] ) ∗ i 3 h 3 = o 3 ∗ t a n h ( c 3 ) c_1 = c_0 * f_1+tanh(W_c\cdot [h_0,x_1])*i_1\\ h_1 = o_1*tanh(c_1)\\ c_2 = c_1 * f_2+tanh(W_c\cdot [h_1,x_2])*i_2\\ h_2 = o_2*tanh(c_2)\\ c_3 = c_2 * f_3+tanh(W_c\cdot [h_2,x_3])*i_3\\ h_3 = o_3*tanh(c_3)\\ c1=c0f1+tanh(Wc[h0,x1])i1h1=o1tanh(c1)c2=c1f2+tanh(Wc[h1,x2])i2h2=o2tanh(c2)c3=c2f3+tanh(Wc[h2,x3])i3h3=o3tanh(c3)
现在要去更新参数,共计四个 W f , W c , W i , W o W_f,W_c,W_i,W_o Wf,Wc,Wi,Wo
∂ E 3 ∂ W f = ∂ E 3 ∂ h 3 ∂ h 3 ∂ c 3 ∂ c 3 ∂ c 2 ∂ c 2 ∂ c 1 ∂ c 1 ∂ W f \frac{\partial E_3}{\partial W_f} = \frac{\partial E_3}{\partial h_3}\frac{\partial h_3}{\partial c_3}\frac{\partial c_3}{\partial c_2}\frac{\partial c_2}{\partial c_1}\frac{\partial c_1}{\partial W_f} WfE3=h3E3c3h3c2c3c1c2Wfc1
求偏导过程中,主要看函数中有 W f W_f Wf的是哪部分,很快我们发现 c i c_i ci函数中直观的包含了 W c , W i W_c,W_i Wc,Wi W f W_f Wf,因此 W c , W i W_c,W_i Wc,Wi W f W_f Wf的偏导公式都相同。
再看参数 W o W_o Wo,它与 h 3 h_3 h3是直接关系。
∂ E 3 ∂ W o = ∂ E 3 ∂ h 3 ∂ h 3 ∂ W o \frac{\partial E_3}{\partial W_o} = \frac{\partial E_3}{\partial h_3}\frac{\partial h_3}{\partial W_o} WoE3=h3E3Woh3
因此我们也可以总结出和RNN类似的公式,即任意时刻误差项公式
∂ E k ∂ W c = ∂ E k ∂ h k ∂ h k ∂ c k ( ∏ t = 2 k ∂ c t ∂ c t − 1 ) ∂ c 1 ∂ W c \frac{\partial E_k}{\partial W_c}=\frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}(\prod_{t=2}^k\frac{\partial c_t}{\partial c_{t-1}})\frac{\partial c_1}{\partial W_c} WcEk=hkEkckhk(t=2kct1ct)Wcc1
W c , W f , W i W_c,W_f,W_i Wc,Wf,Wi 同上。
接着看到 ∂ c t ∂ c t − 1 \frac{\partial c_t}{\partial c_{t-1}} ct1ct,举个t=2的例子:
c 2 = c 1 ∗ f 2 + c 2 ~ ∗ i 2 ∂ c 2 ∂ c 1 = ∂ ( c 1 ∗ f 2 ) ∂ c 1 + ∂ ( c 2 ~ ∗ i 2 ) ∂ c 1 = c 1 ⋅ f 2 ′ + c 1 ′ ⋅ f 2 + c 2 ~ ⋅ i 2 ′ + c 2 ~ ′ ⋅ i 2 = c 1 ⋅ ∂ f 2 ∂ c 1 + f 2 + c 2 ~ ⋅ ∂ i 2 ∂ c 1 + i 2 ⋅ ∂ c 2 ~ ∂ c 1 = c 1 ⋅ ∂ f 2 ∂ h 1 ∂ h 1 ∂ c 1 + f 2 + c 2 ~ ⋅ ∂ i 2 ∂ h 1 ∂ h 1 ∂ c 1 + i 2 ⋅ ∂ c 2 ~ ∂ h 1 ∂ h 1 ∂ c 1 = c 1 ⋅ σ ′ ( W f ⋅ [ h 1 , x 2 ] ) ⋅ W f ⋅ o 1 ⋅ t a n h ′ ( c 1 )     + f 2     + c 2 ~ ⋅ σ ′ ( W i ⋅ [ h 1 , x 2 ] ) ⋅ W i ⋅ o 1 ⋅ t a n h ′ ( c 1 )     + i 2 ⋅ σ ′ ( W c ⋅ [ h 1 , x 2 ] ) ⋅ W c ⋅ o 1 ⋅ t a n h ′ ( c 1 ) \begin{aligned} c_2 &= c_1 * f_2+\tilde{c_2}*i_2\\ \frac{\partial c_2}{\partial c_1} &= \frac{\partial (c_1*f_2)}{\partial c_1}+\frac{\partial (\tilde{c_2}*i_2)}{\partial c_1}\\ &=c_1 \cdot f_2'+c_1'\cdot f_2+\tilde{c_2}\cdot i_2'+\tilde{c_2}'\cdot i_2\\ &=c_1\cdot \frac{\partial f_2}{\partial c_1}+f_2+\tilde{c_2}\cdot \frac{\partial i_2}{\partial c_1}+i_2\cdot \frac{\partial \tilde{c_2}}{\partial c_1} \\ & = c_1\cdot \frac{\partial f_2}{\partial h_1}\frac{\partial h_1}{\partial c_1}+f_2+\tilde{c_2}\cdot \frac{\partial i_2}{\partial h_1}\frac{\partial h_1}{\partial c_1}+i_2\cdot \frac{\partial \tilde{c_2}}{\partial h_1}\frac{\partial h_1}{\partial c_1}\\ &=c_1\cdot \sigma'(W_f\cdot [h_1,x_2]) \cdot W_f \cdot o_1 \cdot tanh'(c_1)\\ &\ \ \ +f_2\\ &\ \ \ +\tilde{c_2}\cdot \sigma'(W_i\cdot [h_1,x_2]) \cdot W_i \cdot o_1 \cdot tanh'(c_1)\\ &\ \ \ + i_2 \cdot \sigma'(W_c\cdot [h_1,x_2]) \cdot W_c \cdot o_1 \cdot tanh'(c_1) \end{aligned} c2c1c2=c1f2+c2~i2=c1(c1f2)+c1(c2~i2)=c1f2+c1f2+c2~i2+c2~i2=c1c1f2+f2+c2~c1i2+i2c1c2~=c1h1f2c1h1+f2+c2~h1i2c1h1+i2h1c2~c1h1=c1σ(Wf[h1,x2])Wfo1tanh(c1)   +f2   +c2~σ(Wi[h1,x2])Wio1tanh(c1)   +i2σ(Wc[h1,x2])Wco1tanh(c1)

注意!我们之所以不直接把 c 1 c_1 c1提出来,是因为 f 2 f_2 f2 中是包含了 c 1 c_1 c1 的!! f 2 f_2 f2包含了 h 1 h_1 h1,而 h 1 h_1 h1包含了 c 1 c_1 c1
作为简单的记忆,我们就把 ∂ c 2 ∂ c 1 \frac{\partial c_2}{\partial c_1} c1c2拆成了四项,除了 f t f_t ft那一项,其他都是一个套路。

配上一个可视的BPTT

总结一下偏导公式:
∂ c t ∂ c t − 1 = ∂ ∂ c t − 1 [ c t − 1 ∗ f t + c t ~ ∗ i t ] = ∂ ∂ c t − 1 [ c t − 1 ∗ f t ] + ∂ ∂ c t − 1 [ c t ~ ∗ i t ] = ∂ f t ∂ c t − 1 ⋅ c t − 1 + ∂ c t − 1 ∂ c t − 1 ⋅ f t + ∂ i t ∂ c t − 1 ⋅ c t ~ + ∂ c t ~ ∂ c t − 1 ⋅ i t = ∂ f t ∂ h t − 1 ⋅ ∂ h t − 1 ∂ c t − 1 ⋅ c t − 1 + ∂ i t ∂ h t − 1 ⋅ ∂ h t − 1 ∂ c t − 1 ⋅ c t ~ + ∂ c t ~ ∂ h t − 1 ⋅ ∂ h t − 1 ∂ c t − 1 ⋅ i t = σ ′ ( W f ⋅ [ h t − 1 , x t ] ) ⋅ W f ⋅ o t − 1 ⋅ t a n h ′ ( c t − 1 ) ⋅ c t − 1     + f t     + σ ′ ( W i ⋅ [ h t − 1 , x t ] ) ⋅ W i ⋅ o t − 1 ⋅ t a n h ′ ( c t − 1 ) ⋅ c t ~     + σ ′ ( W c ⋅ [ h t − 1 , x t ] ) ⋅ W c ⋅ o t − 1 ⋅ t a n h ′ ( c t − 1 ) ⋅ i t \begin{aligned} \frac{\partial c_t}{\partial c_{t-1}}&=\frac{\partial }{\partial c_{t-1}}[c_{t-1} * f_t + \tilde{c_t} * i_t]\\ &= \frac{\partial }{\partial c_{t-1}}[c_{t-1} * f_t]+\frac{\partial }{\partial c_{t-1}}[\tilde{c_{t}} * i_t]\\ &=\frac{\partial f_t}{\partial c_{t-1}}\cdot c_{t-1}+\frac{\partial c_{t-1}}{\partial c_{t-1}}\cdot f_t+\frac{\partial i_t}{\partial c_{t-1}}\cdot \tilde{c_t}+\frac{\partial \tilde{c_t}}{\partial c_{t-1}}\cdot i_t\\ &= \frac{\partial f_t}{\partial h_{t-1}} \cdot \frac{\partial h_{t-1}}{\partial c_{t-1}}\cdot c_{t-1}+\frac{\partial i_t}{\partial h_{t-1}}\cdot \frac{\partial h_{t-1}}{\partial c_{t-1}}\cdot\tilde{c_t}+\frac{\partial \tilde{c_t}}{\partial h_{t-1}}\cdot \frac{\partial h_{t-1}}{\partial c_{t-1}}\cdot i_t\\ &= \sigma'(W_f\cdot [h_{t-1},x_t]) \cdot W_f \cdot o_{t-1} \cdot tanh'(c_{t-1}) \cdot c_{t-1}\\ & \ \ \ +f_t\\ & \ \ \ + \sigma'(W_i\cdot [h_{t-1},x_t]) \cdot W_i \cdot o_{t-1} \cdot tanh'(c_{t-1}) \cdot \tilde{c_t}\\ & \ \ \ + \sigma'(W_c\cdot [h_{t-1},x_t]) \cdot W_c \cdot o_{t-1} \cdot tanh'(c_{t-1}) \cdot i_t \end{aligned} ct1ct=ct1[ct1ft+ct~it]=ct1[ct1ft]+ct1[ct~it]=ct1ftct1+ct1ct1ft+ct1itct~+ct1ct~it=ht1ftct1ht1ct1+ht1itct1ht1ct~+ht1ct~ct1ht1it=σ(Wf[ht1,xt])Wfot1tanh(ct1)ct1   +ft   +σ(Wi[ht1,xt])Wiot1tanh(ct1)ct~   +σ(Wc[ht1,xt])Wcot1tanh(ct1)it
如果把这四项分别用 A , B , C , D A,B,C,D A,B,C,D替代的话,公式就可以变成:
∂ c t ∂ c t − 1 = A t + B t + C t + D t \frac{\partial c_t}{\partial c_{t-1}}=A_t+B_t+C_t+D_t ct1ct=At+Bt+Ct+Dt

把这个简洁的公式带入之前的误差公式:
∂ E k ∂ W = ∂ E k ∂ h k ∂ h k ∂ c k ( ∏ t = 2 k ∂ c t ∂ c t − 1 ) ∂ c 1 ∂ W = ∂ E k ∂ h k ∂ h k ∂ c k ( ∏ t = 2 k [ A t + B t + C t + D t ] ) ∂ c 1 ∂ W \begin{aligned} \frac{\partial E_k}{\partial W}&=\frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}(\prod_{t=2}^k\frac{\partial c_t}{\partial c_{t-1}})\frac{\partial c_1}{\partial W}\\ &=\frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}(\prod_{t=2}^k[A_t+B_t+C_t+D_t])\frac{\partial c_1}{\partial W} \end{aligned} WEk=hkEkckhk(t=2kct1ct)Wc1=hkEkckhk(t=2k[At+Bt+Ct+Dt])Wc1

缓解梯度消失/爆炸问题

有连乘,那就说明有可能造成梯度消失和爆炸。上文也讲了 ∏ t = 2 k ∂ c t ∂ c t − 1 \prod_{t=2}^k\frac{\partial c_t}{\partial c_{t-1}} t=2kct1ct里面有什么,总共四项,如果看的云里雾里也没事,因为那个 B t B_t Bt 你一定看的懂!因为 B t B_t Bt 只有一个内容 f t f_t ft,我们可以轻松地直观地通过他调整 f t f_t ft 的大小以适应其他三个项的值,然后是的连乘出来的结果不会非常小。

接下来我们看 f t f_t ft 到底怎么能帮助我们。现在假设对某一个时间步 k < T k<T k<T,我们有:
∑ t = 1 k ∂ E t ∂ W → 0 \sum_{t=1}^k \frac{\partial E_t}{\partial W}\rightarrow 0 t=1kWEt0
然后为了梯度不消失,我们可以再时间步 k + 1 k+1 k+1 找到一个合适的 W f W_f Wf 使得:
∂ E k + 1 ∂ W ↛ 0 \frac{\partial E_{k+1}}{\partial W}\nrightarrow 0 WEk+10
由于遗忘门的激活函数和梯度项中大家都是相加的(A,B,C,D,加性结构),所以使得 LSTM 在任何时间步都能找到这样的 W f W_f Wf 使得:
∑ t = 1 k + 1 ∂ E t ∂ W ↛ 0 \sum_{t=1}^{k+1}\frac{\partial E_t}{\partial W}\nrightarrow 0 t=1k+1WEt0
这样梯度就不会消失了。

另一个重要的性质: 正如上文说到的加性结构,四个项可以相互平衡从而保证在反向传播的时候梯度值不会消失。

举个例子:假设 时间步 t ∈ { 2 , 3 , . . . , k } t \in \{2,3,...,k\} t{2,3,...,k},我们对梯度值中的四项设置一个相互平衡的值(从而保证梯度不消失):
A t ≈ 0.1 ⃗ , B t = f t ≈ 0.7 ⃗ , C t ≈ 0.1 ⃗ , D t ≈ 0.1 ⃗ A_t \approx \vec{0.1},B_t = f_t \approx \vec{0.7},C_t \approx \vec{0.1},D_t \approx \vec{0.1} At0.1 ,Bt=ft0.7 ,Ct0.1 ,Dt0.1
带入连乘公式:
∏ t = 2 k [ A t + B t + C t + D t ] ≈ ∏ t = 2 k [ 0.1 ⃗ + 0.7 ⃗ + 0.1 ⃗ + 0.1 ⃗ ] ≈ ∏ t = 2 k 1 ⃗ ↛ 0 \begin{aligned} \prod_{t=2}^k[A_t+B_t+C_t+D_t] &\approx \prod_{t=2}^k[\vec{0.1}+\vec{0.7}+\vec{0.1}+\vec{0.1}]\\ & \approx \prod_{t=2}^k \vec{1} \nrightarrow 0\end{aligned} t=2k[At+Bt+Ct+Dt]t=2k[0.1 +0.7 +0.1 +0.1 ]t=2k1 0
这时候就算是连乘,梯度也不会消失了。

所以,在 LSTM 中,遗忘门的存在,以及细胞状态梯度的加性特性,使网络能够以这样一种方式更新参数,即不同子梯度之间的平衡从而避免梯度消失
但看到这,也就清楚了,因为我们都是正数相加,所以不能够避免梯度爆炸,当 A , C , D A,C,D A,C,D的数值很大的时候, f t f_t ft 也没办法去平衡防止梯度爆炸。

  • 9
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值