DL notes 01:RNN/LSTM/GRU

一、RNN基本结构、梯度消失和梯度爆炸的原因

线性计算单元组成的RNN结构是最简单的一种,我们以此为例来说明造成梯度消失和梯度爆炸的原因:
RNN-simple.png
上图为线性计算单元组成的RNN.依据上图,我们现假设存在一个RNN模型仅包含一个隐藏层,整个RNN模型关注的时间步数为3, H 0 H_{0} H0是隐藏层的初始状态,则可以用如下算式表示前向传播过程:
H 1 = W X X 1 + W H H 0 + b H , O 1 = W O H 1 + b O H_{1} = W_{X}X_{1} + W_{H}H_{0} + b_{H}, O{1} = W_{O}H_{1} + b_{O} H1=WXX1+WHH0+bH,O1=WOH1+bO H 2 = W X X 2 + W H H 1 + b H , O 2 = W O H 2 + b O H_{2} = W_{X}X_{2} + W_{H}H_{1} + b_{H}, O{2} = W_{O}H_{2} + b_{O} H2=WXX2+WHH1+bH,O2=WOH2+bO H 3 = W X X 3 + W H H 2 + b H , O 3 = W O H 3 + b O H_{3} = W_{X}X_{3} + W_{H}H_{2} + b_{H}, O{3} = W_{O}H_{3} + b_{O} H3=WXX3+WHH2+bH,O3=WOH3+bO

其中 X = [ X 0 , X 1 , X 2 , . . X t ] X =[X_{0},X_{1},X_{2},..X_{t}] X=[X0,X1,X2,..Xt], Y = [ Y 0 , Y 1 , Y 2 , . . Y t ] Y =[Y_{0},Y_{1},Y_{2},..Y_{t}] Y=[Y0,Y1,Y2,..Yt]代表 t t t时间步长的特征数据和真值标签。 H = [ H 0 , H 1 , H 2 , . . H t ] H =[H_{0},H_{1},H_{2},..H_{t}] H=[H0,H1,H2,..Ht] 代表RNN中的隐节点状态, W X , W H , W O W_{X},W_{H},W_{O} WX,WH,WO是各层的权重, b H , b O b_{H},b_{O} bH,bO是偏置。
因为是最简单的线性计算单元,我们假设最后使用MSE作为损失函数: L t = 1 2 ( Y t − O t ) 2 L_{t} = \frac{1}{2}\left(Y{t}-O_{t}\right)^{2} Lt=21(YtOt)2
对于整个RNN的训练,我们通常需要统计每个时刻的损失之和或平均值,这里我们以每个时间步数累积损失作为整个模型的训练损失:
L = ∑ t = 0 T L t L = \sum_{t=0}^{T} L_{t} L=t=0TLt
假设目前只考虑初始的三个时间步长,我们选取最长一条反向传播通路为例,即以 L 3 L_{3} L3对RNN中的参数求偏导:
∂ L 3 ∂ W O = ∂ L 3 ∂ O 3 ∂ O 3 ∂ W O \frac{\partial L_{3}}{\partial W_{O}} = \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial W_{O}} WOL3=O3L3WOO3

∂ L 3 ∂ W X = ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ W X + ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ H 2 ∂ H 2 ∂ W X + ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ H 2 ∂ H 2 ∂ H 1 ∂ H 1 ∂ W X \frac{\partial L_{3}}{\partial W_{X}} = \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial W_{X}} + \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial H_{2}} \frac{\partial H_{2}}{\partial W_{X}} + \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial H_{2}} \frac{\partial H_{2}}{\partial H_{1}} \frac{\partial H_{1}}{\partial W_{X}} WXL3=O3L3H3O3WXH3+O3L3H3O3H2H3WXH2+O3L3H3O3H2H3H1H2WXH1

∂ L 3 ∂ W H = ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ W H + ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ H 2 ∂ H 2 ∂ W H + ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ H 2 ∂ H 2 ∂ H 1 ∂ H 1 ∂ W H \frac{\partial L_{3}}{\partial W_{H}} = \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial W_{H}} + \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial H_{2}} \frac{\partial H_{2}}{\partial W_{H}} + \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial H_{2}} \frac{\partial H_{2}}{\partial H_{1}} \frac{\partial H_{1}}{\partial W_{H}} WHL3=O3L3H3O3WHH3+O3L3H3O3H2H3WHH2+O3L3H3O3H2H3H1H2WHH1

从中可以发现根据链式法则求导, W X , W H W_{X},W_{H} WX,WH 会重复出现,这是由于 W X , W H W_{X},W_{H} WX,WH在每个时间步中都参与隐节点 H H H的状态估计,随着时间步数的增加,反向传播的路径将会相应的延长。相应的,我们可以总结出任意时刻 t t t损失函数对 W X W_{X} WX的偏导:
∂ L t ∂ W X = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ H t ( ∏ j = k + 1 t ∂ H j ∂ H j − 1 ) ∂ H k ∂ W X \frac{\partial L_{t}}{\partial W_{X}} = \sum_{k=0}^{t} \frac{\partial L_{t}}{\partial O_{t}} \frac{\partial O_{t}}{\partial H_{t}} \left( \prod_{j=k+1}^{t} \frac{\partial H_{j}}{\partial H_{j-1}} \right) \frac{\partial H_{k}}{\partial W_{X}} WXLt=k=0tOtLtHtOtj=k+1tHj1HjWXHk
任意时刻 t t t损失函数对 W H W_{H} WH求偏导同上,只需替换 W X W_{X} WX
如果隐藏层的激活函数为 t a n h tanh tanh,则有 H j = t a n h ( W X X j + W H H j − 1 + b H ) H_{j} = tanh(W_{X}X_{j}+W_{H}H_{j-1}+b_{H}) Hj=tanh(WXXj+WHHj1+bH) ∏ j = k + 1 t ∂ H j ∂ H j − 1 = ∏ j = k + 1 t t a n h ′ ⋅ W H \prod_{j=k+1}^{t} \frac{\partial H_{j}}{\partial H_{j-1}} = \prod_{j=k+1}^{t} tanh^{'}·W_{H} j=k+1tHj1Hj=j=k+1ttanhWH
激活函数 t a n h tanh tanh的导数有如下特性:
t a n h ′ x = 1 − t a n h 2 x tanh^{'} x= 1-tanh^{2}x tanhx=1tanh2x
因此可知 0 < t a n h ′ ≤ 1 0< tanh^{'} \le 1 0<tanh1
在训练过程中, H j H_{j} Hj的状态极少情况下为0, t a n h ′ tanh^{'} tanh通常是小于1的,如果 W H W_{H} WH值在 ( 0 , 1 ) (0,1) (0,1)范围内,则 ∏ j = k + 1 t t a n h ′ ⋅ W H \prod_{j=k+1}^{t} tanh^{'}·W_{H} j=k+1ttanhWH 趋近于0,导致梯度消失。
如果 W H W_{H} WH值很大,则会导致 ∏ j = k + 1 t t a n h ′ ⋅ W H \prod_{j=k+1}^{t} tanh^{'}·W_{H} j=k+1ttanhWH 连乘后结果趋近于无穷,导致梯度爆炸。
如何避免这种现象?在RNN的课程学习中,提到裁剪梯度的方法,假设我们把所有模型参数的梯度拼接成一个向量 g \boldsymbol{g} g ,并设裁剪的阈值是 θ \theta θ 。裁剪后的梯度即:
min ⁡ ( θ ∥ g ∥ , 1 ) g \min\left(\frac{\theta}{\|\boldsymbol{g}\|}, 1\right)\boldsymbol{g} min(gθ,1)g
梯度的的 L 2 L_{2} L2 范数不超过 θ \theta θ。函数如下:

def grad_clipping(params, theta, device):
    norm = torch.tensor([0.0], device=device)
    for param in params:
        norm += (param.grad.data ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            param.grad.data *= (theta / norm)

裁剪梯度是一个简单好用的防止梯度爆炸的方法,实际上造成梯度衰减或梯度爆炸的根本原因是 ∏ j = k + 1 t ∂ H j ∂ H j − 1 \prod_{j=k+1}^{t} \frac{\partial H_{j}}{\partial H_{j-1}} j=k+1tHj1Hj这一连乘项,理想的消除方法就是使 ∂ H j ∂ H j − 1 ∈ [ 0 , 1 ] \frac{\partial H_{j}}{\partial H_{j-1}} \in [0,1] Hj1Hj[0,1]其实这就是LSTM做的事情,至于细节如何则会在后续的篇幅中加以介绍。

二、LSTM

长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN结构单元,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
lstm
I t = σ ( X t W x i + H t − 1 W h i + b i ) I_t = σ(X_tW_{xi} + H_{t−1}W_{hi} + b_i) It=σ(XtWxi+Ht1Whi+bi) F t = σ ( X t W x f + H t − 1 W h f + b f ) F_t = σ(X_tW_{xf} + H_{t−1}W_{hf} + b_f) Ft=σ(XtWxf+Ht1Whf+bf) O t = σ ( X t W x o + H t − 1 W h o + b o ) O_t = σ(X_tW_{xo} + H_{t−1}W_{ho} + b_o) Ot=σ(XtWxo+Ht1Who+bo) C ~ t = t a n h ( X t W x c + H t − 1 W h c + b c ) \widetilde{C}_t = tanh(X_tW_{xc} + H_{t−1}W_{hc} + b_c) C t=tanh(XtWxc+Ht1Whc+bc) C t = F t ⊙ C t − 1 + I t ⊙ C ~ t C_t = F_t ⊙C_{t−1} + I_t ⊙\widetilde{C}_t Ct=FtCt1+ItC t H t = O t ⊙ t a n h ( C t ) H_t = O_t⊙tanh(C_t) Ht=Ottanh(Ct) Y t = ϕ ( H t W h y + b y ) Y_{t} = \phi(H_{t}W_{hy}+b_{y}) Yt=ϕ(HtWhy+by)
相比RNN只有一个传递状态 H t H_{t} Ht,LSTM有两个传输状态,一个 C t C_{t} Ct (cell state),和一个 H t H_{t} Ht (hidden state)。其中对于传递下去的 C t C_{t} Ct 改变得很慢,通常输出的 C t C_{t} Ct 是上一个状态传过来的 C t − 1 C_{t-1} Ct1 加上一些数值。而 H t H_{t} Ht 则在不同节点下往往会有很大的区别。 ⊙ \odot 是Hadamard Product,也就是操作矩阵中对应的元素相乘,因此要求两个相乘矩阵是同型的。 ⊕ \oplus 则代表进行矩阵加法。

LSTM主要包括以下几个结构:

  • 遗忘门:控制上一时间步的记忆细胞
  • 输入门:控制当前时间步的输入
  • 输出门:控制从记忆细胞到隐藏状态
  • 记忆细胞:⼀种特殊的隐藏状态的信息的流动
    三个门输出都经过诸如 s i g m o i d sigmoid sigmoid的激活函数,映射到 ( 0 , 1 ) (0,1) (0,1)的范围,形成门控状态。而记忆细胞则是通过 t a n h tanh tanh转换成 ( − 1 , 1 ) (-1,1) (1,1)的范围,这里作为记忆细胞短期依赖输出而非门控信号。

LSTM 内部主要有三个阶段:

  1. 忘记阶段。这个阶段主要是对上一个节点传进来的输入进行选择性忘记。简单来说就是会 “忘记不重要的,记住重要的”。具体来说是通过计算得到的 F t F_{t} Ft (f表示forget)来作为忘记门控,来控制上一个状态的 C t − 1 C_{t-1} Ct1 哪些需要留哪些需要忘。

  2. 选择记忆阶段。这个阶段将这个阶段的输入有选择性地进行“记忆”。主要是会对输入 X t X_{t} Xt进行选择记忆。哪些重要则着重记录下来,哪些不重要,则少记一些。当前的输入内容由前面计算得到的 C ~ t \widetilde{C}_t C t表示。而选择的门控信号则是由 I t I_{t} It (i代表information)来进行控制。

将上面两步得到的结果相加,即可得到传输给下一个状态的 C t C_{t} Ct 。也就是上图中的第一个公式。这里也使用 t a n h tanh tanh起到对输入的信息进行压缩的作用。

  1. 输出阶段。这个阶段将决定哪些将会被当成当前状态的输出。主要是通过 O t O_{t} Ot来进行控制的。并且还对上一阶段得到的 C t C_{t} Ct进行了放缩(通过一个tanh激活函数进行变化)。

与普通RNN类似,输出 Y t Y_{t} Yt往往最终也是通过 H t H_{t} Ht变化得到。

三、GRU

GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。

GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢?其实通过代码我们就可以发现,GRU的权重参数相比LSTM减少了1/4。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。GRU
R t = σ ( X t W x r + H t − 1 W h r + b r ) R_{t} = σ(X_tW_{xr} + H_{t−1}W_{hr} + b_r) Rt=σ(XtWxr+Ht1Whr+br) Z t = σ ( X t W x z + H t − 1 W h z + b z ) Z_{t} = σ(X_tW_{xz} + H_{t−1}W_{hz} + b_z) Zt=σ(XtWxz+Ht1Whz+bz) H ~ t − 1 = R t ⊙ H t − 1 \widetilde{H}_{t-1} = R_t ⊙H_{t−1} H t1=RtHt1 H ~ t = t a n h ( X t W x h + H ~ t − 1 W h h + b h ) \widetilde{H}_t = tanh(X_tW_{xh} + \widetilde{H}_{t-1}W_{hh} + b_h) H t=tanh(XtWxh+H t1Whh+bh) H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t H_t = Z_t⊙H_{t−1} + (1−Z_t)⊙\widetilde{H}_t Ht=ZtHt1+(1Zt)H t

GRU很聪明的一点就在于,我们使用了同一个门控 Z t Z_t Zt 就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。 R t R_t Rt在GRU中被称作重置⻔有助于捕捉时间序列⾥短期的依赖关系; Z t Z_t Zt被称作更新⻔有助于捕捉时间序列⾥⻓期的依赖关系。
Z t ⊙ H t − 1 Z_t⊙H_{t−1} ZtHt1:表示对原本隐藏状态的选择性“遗忘”。这里的 Z t Z_t Zt 可以想象成遗忘门(forget gate),忘记 H t − 1 H_{t−1} Ht1 维度中一些不重要的信息。
( 1 − Z t ) ⊙ H ~ t (1−Z_t)⊙\widetilde{H}_t (1Zt)H t: 表示对包含当前节点信息的 H ~ t \widetilde{H}_t H t 进行选择性”记忆“。与上面类似,这里的 ( 1 − Z t ) (1−Z_t) (1Zt)同理会忘记 H ~ t \widetilde{H}_t H t维度中的一些不重要的信息。或者,这里我们更应当看做是对 H ~ t \widetilde{H}_t H t 维度中的某些信息进行选择。
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ~ t H_t = Z_t⊙H_{t−1} + (1−Z_t)⊙\widetilde{H}_t Ht=ZtHt1+(1Zt)H t :结合上述,这一步的操作就是忘记传递下来的 H t − 1 H_{t−1} Ht1 中的某些维度信息,并加入当前节点输入的某些维度信息。

可以看到,这里的遗忘 Z t Z_t Zt和选择 ( 1 − Z t ) (1−Z_t) (1Zt) 是联动的。也就是说,对于传递进来的维度信息,我们会进行选择性遗忘,则遗忘了多少权重 ( Z t Z_t Zt),我们就会使用包含当前输入的 H ~ t \widetilde{H}_t H t 中所对应的权重进行弥补 ( 1 − Z t ) (1−Z_t) (1Zt)。以保持一种”恒定“状态。

四、深度循环神经网络

deepRNN
H t ( 1 ) = ϕ ( X t W x h ( 1 ) + H t − 1 ( 1 ) W h h ( 1 ) + b h ( 1 ) ) \boldsymbol{H}_t^{(1)} = \phi(\boldsymbol{X}_t \boldsymbol{W}_{xh}^{(1)} + \boldsymbol{H}_{t-1}^{(1)} \boldsymbol{W}_{hh}^{(1)} + \boldsymbol{b}_h^{(1)}) Ht(1)=ϕ(XtWxh(1)+Ht1(1)Whh(1)+bh(1)) H t ( ℓ ) = ϕ ( H t ( ℓ − 1 ) W x h ( ℓ ) + H t − 1 ( ℓ ) W h h ( ℓ ) + b h ( ℓ ) ) \boldsymbol{H}_t^{(\ell)} = \phi(\boldsymbol{H}_t^{(\ell-1)} \boldsymbol{W}_{xh}^{(\ell)} + \boldsymbol{H}_{t-1}^{(\ell)} \boldsymbol{W}_{hh}^{(\ell)} + \boldsymbol{b}_h^{(\ell)}) Ht()=ϕ(Ht(1)Wxh()+Ht1()Whh()+bh()) O t = H t ( L ) W h q + b q \boldsymbol{O}_t = \boldsymbol{H}_t^{(L)} \boldsymbol{W}_{hq} + \boldsymbol{b}_q Ot=Ht(L)Whq+bq
在pytorch的实现中以num_layers参数进行层数的设置和调整。

五、双向循环神经网络

Bi-RNN
H → t = ϕ ( X t W x h ( f ) + H → t − 1 W h h ( f ) + b h ( f ) ) \overrightarrow{\boldsymbol{H}}_t = \phi(\boldsymbol{X}_t \boldsymbol{W}_{xh}^{(f)} + \overrightarrow{\boldsymbol{H}}_{t-1} \boldsymbol{W}_{hh}^{(f)} + \boldsymbol{b}_h^{(f)}) H t=ϕ(XtWxh(f)+H t1Whh(f)+bh(f)) H ← t = ϕ ( X t W x h ( b ) + H ← t + 1 W h h ( b ) + b h ( b ) ) \overleftarrow{\boldsymbol{H}}_t = \phi(\boldsymbol{X}_t \boldsymbol{W}_{xh}^{(b)} + \overleftarrow{\boldsymbol{H}}_{t+1} \boldsymbol{W}_{hh}^{(b)} + \boldsymbol{b}_h^{(b)}) H t=ϕ(XtWxh(b)+H t+1Whh(b)+bh(b)) H t = ( H → t , H ← t ) \boldsymbol{H}_t=(\overrightarrow{\boldsymbol{H}}_{t}, \overleftarrow{\boldsymbol{H}}_t) Ht=(H t,H t) O t = H t W h q + b q \boldsymbol{O}_t = \boldsymbol{H}_t \boldsymbol{W}_{hq} + \boldsymbol{b}_q Ot=HtWhq+bq
在pytorch的实现中以bidirectional参数进行层数的设置和调整。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值