RNN和LSTM的理论推导


注:LSTM-BP待日后填坑

第一章 RNN的提出背景和作用

1.1RNN的提出背景

       在传统的神经网络中,各个输入在算法内部是相对独立的,无法从先前的信息中进行推理,难以处理序列类型的数据。

1.2RNN的作用

        在BP算法提出之后,学者和研究员们又提出了具有短期记忆能力的循环神经网络。RNN能记住之前的输入值,即使后面有相同的输入,输出值也会不同。目前已被广泛应用在语音识别,语言模型以及自然语言生成等任务上。

第二章 RNN的理论推导

2.1 RNN的网络结构

       RNN的基本结构如下图所示,左侧是RNN网络,右侧是RNN网络按时序展开的形式。
图源wiki

       RNN在传统神经网络输入层——隐藏层——输出层的基础上增加了一个类似于延时器的单元,能够记录上一次的输出值,并带到下一次的输入当中,以此记录最近几次的活性值V,RNN也由此具备了短期记忆能力。
       将RNN按时序展开就可以得到RNN的结构

  • 代表t时刻的输入
  • 代表t时刻隐藏层的状态
  • 代表t时刻的输出
  • 代表输入层到隐藏层的权重
  • V代表隐藏状态到下一隐藏状态的权重
  • W代表隐藏层到输出层的权重

       U,V,W是该模型中的线性关系参数,它在整个网络中是共享的,体现出RNN模型“循环反馈”的思想。

2.2 RNN的前向传播

       RNN以时刻t为参数进行循环。虽然每一时刻的输入不同,但其对应的结构不变,所以每一次循环就相当于在进行递归,递推公式如下:

h t = σ ( x t U + h t − 1 V + b h ) h_{t}=\sigma(x_{t}U+h_{t-1}V+b_{h}) ht=σ(xtU+ht1V+bh)

       其中 σ \sigma σ为RNN的激活函数,一般为tanh。是该线性关系中的的偏置。
       最终输出的表达式为:
y t ^ = φ ( W h t + b y ) \hat{y_{t}}=\varphi(Wh_{t}+b_{y}) yt^=φ(Wht+by)
       激活函数 φ \varphi φ一般是softmax。是该线性关系中的的偏置。

2.3 RNN的反向传播

       类似于传统神经网络,RNN神经网络的反向传播算法思路也是通过梯度下降法一轮轮的选代。由于是基于时间的反向传播,所以将RNN神经网络的反向传播命名为BPTT(back-propagation through time)。我们利用反向传播算法将输出层的误差加和,然后对各个权重的参数矩阵求梯度,再利用梯度下降法更新各个权重。
       对于每一时刻t的RNN网络,网络的输出在每个时刻都会产生损失。那么总的损失为。我们的目标就是要求取对的偏导。
       对于预测结果的任意损失函数,求取是最简单的,我们可以直接求取每个时刻的,由于它不存在和之前的依赖状态,可以直接求导取得,然后简单求和即可,算式如下:
∂ l o s s ∂ W = ∑ t = 1 T ∂ L t ∂ y t ^ ⋅ ∂ y t ^ ∂ W \frac{\partial loss}{\partial W}=\sum_{t=1}^{T}\mathop{\frac{\partial L_{t}}{\partial\hat{y_{t}}}\cdot\frac{\partial\hat{y_{t}}}{\partial W}} Wloss=t=1Tyt^LtWyt^
       对于V的计算不能直接求导,因此需要用链式求导法则。
       以对V求梯度为例
∂ l o s s ∂ V = ∑ t = 1 T ∂ L t ∂ V = ∑ t = 1 T ∂ L t ∂ y t ^ ⋅ ∂ y t ^ ∂ h t ⋅ ∂ h t ∂ V ( 1 ) \frac{\partial loss}{\partial V}=\sum_{t=1}^{T}\mathord{\frac{\partial L_{t}}{\partial V}}=\sum_{t=1}^{T}\mathop{\frac{\partial L_{t}}{\partial\hat{y_{t}}}\cdot\frac{\partial\hat{y_{t}}}{\partial {h_t}}\cdot\frac{\partial h_{t}}{\partial V}} \quad(1) Vloss=t=1TVLt=t=1Tyt^Lthtyt^Vht(1)
       由公式 h t = σ ( x t U + h t − 1 V + b h ) h_{t}=\sigma(x_{t}U+h_{t-1}V+b_{h}) ht=σ(xtU+ht1V+bh) ∂ h t ∂ V \frac{\partial h_{t}}{\partial V} Vht单独进行展开,可得
∂ h t ∂ V = ∑ k = 1 t ∂ h t ∂ h k ⋅ ∂ h k ∂ V = ∑ k = 1 t ( ∏ j = k + 1 t h j h j − 1 ) ⋅ ∂ h t ∂ V ( 2 ) \frac{\partial h_{t}}{\partial V}=\sum_{k=1}^{t}\mathop{\frac{\partial{h_{t}}}{\partial {h_k}}\cdot\frac{\partial h_{k}}{\partial V}} =\sum_{k=1}^{t}\mathbin{(\prod_{j=k+1}^{t}\mathbin{\frac{h_{j}}{h_{j-1}})}\cdot\frac{\partial h_{t}}{\partial V}} \quad(2) Vht=k=1thkhtVhk=k=1t(j=k+1thj1hj)Vht(2)

       将(2)式代入(1)式,得
∂ l o s s ∂ V = ∑ t = 1 T ∂ L t ∂ y t ^ ⋅ ∂ y t ^ ∂ h t ∑ k = 1 t ( ∏ j = k + 1 t h j h j − 1 ) ⋅ ∂ h t ∂ V \frac{\partial loss}{\partial V}=\sum_{t=1}^{T}\mathop{\frac{\partial L_{t}}{\partial\hat{y_{t}}}\cdot\frac{\partial\hat{y_{t}}}{\partial {h_t}}\sum_{k=1}^{t}\mathbin{(\prod_{j=k+1}^{t}\mathbin{\frac{h_{j}}{h_{j-1}})}\cdot\frac{\partial h_{t}}{\partial V}}} Vloss=t=1Tyt^Lthtyt^k=1t(j=k+1thj1hj)Vht

       同理可得对U的梯度
∂ l o s s ∂ U = ∑ t = 1 T ∂ L t ∂ y t ^ ⋅ ∂ y t ^ ∂ h t ∑ k = 1 t ( ∏ j = k + 1 t h j h j − 1 ) ⋅ ∂ h t ∂ U \frac{\partial loss}{\partial U}=\sum_{t=1}^{T}\mathop{\frac{\partial L_{t}}{\partial\hat{y_{t}}}\cdot\frac{\partial\hat{y_{t}}}{\partial {h_t}}\sum_{k=1}^{t}\mathbin{(\prod_{j=k+1}^{t}\mathbin{\frac{h_{j}}{h_{j-1}})}\cdot\frac{\partial h_{t}}{\partial U}}} Uloss=t=1Tyt^Lthtyt^k=1t(j=k+1thj1hj)Uht

第三章 传统RNN模型的缺陷和LSTM的提出

3.1 RNN的梯度消失和梯度爆炸

       在计算t时刻损失产生的梯度时,必须回溯之前所有时刻的信息。
       但是我们会发现一个问题,即最后要对所有时刻的梯度进行累加。而每个时刻都是在后一个时刻的基础上进行累乘的结果。
       若累乘的次数过于庞大,每次都连续乘一个小于1的数字,就会导致最终结果趋近于0,即梯度消失。
       反之,当每次都连续乘一个大于1的数字,就会导致最终的结果趋近于无穷,即梯度爆炸。

3.2 RNN缺陷的解决方案

       为了克服梯度爆炸和梯度消失问题,最直观的想法就是使 h j h j − 1 = 1 \frac{h_{j}}{h_{j-1}}=1 hj1hj=1。梯度裁剪,通过把沿梯度下降方向的步长限制在一个范围内,解决了梯度爆炸的问题,但梯度消失的问题仍难以解决。1997年,Hochreiter和Schmidhuber首先提出了LSTM的网络结构,通过CEC(constant error carrousel)单元,控制其结果为0或接近于1,解决了传统RNN的这一缺陷。

第四章 LSTM的理论推导

4.1 LSTM的基本结构

       LSTM的基本结构如下图所示:
图源wiki

       LSTM由遗忘门、输入门、输出门和细胞状态组成。

4.2 LSTM的遗忘门

       遗忘门能够以一定的概率控制是否遗忘上一层的细胞状态。其数学表达式为:
f t = σ ( W f ⋅ [ x t , h t − 1 ] + b f ) f_{t}=\sigma(W_{f}\cdot [x_{t},h_{t-1}]+b_{f}) ft=σ(Wf[xt,ht1]+bf)
       其中激活函数 σ \sigma σ为sigmoid,为该线性关系中的偏置。上一时刻的输出和本时刻的输入作为该函数的输入,通过激活函数,得到遗忘门。
       sigmoid函数的输出在[0,1]之间,表示细胞状态中被保留信息的概率值。输出越接近0,代表被遗忘的信息越多;输出越接近于1,代表被遗忘的信息越少。

4.3 LSTM的输入门

       输入门由两部分组成,能处理当前时刻的输入,进行数据加强,并对细胞状态进行更新。
       第一部分会进行两个操作。
       1.为忽略因子,能够决定删除哪些部分,其数学表达式为:
i t = σ ( W i ⋅ [ x t , h t − 1 ] + b i ) i_{t}=\sigma(W_{i}\cdot [x_{t},h_{t-1}]+b_{i}) it=σ(Wi[xt,ht1]+bi)
       2.使用激活函数tanh创建了一个新的输入,其数学表达式为:
a t = t a n h ( W f ⋅ [ x t , h t − 1 ] + b a ) a_{t}=tanh(W_{f}\cdot [x_{t},h_{t-1}]+b_{a}) at=tanh(Wf[xt,ht1]+ba)
       第二部分,对细胞状态进行更新。将和相乘,遗忘掉不必要的信息,再将新的输入值与,完成从到的更新,其数学表达式为:
C t = f t ⊙ C t − 1 + i t ⊙ a t C_{t}=f_{t} \odot C_{t-1}+i_{t} \odot a_{t} Ct=ftCt1+itat

4.4 LSTM的输出门

       输出门会基于刚更新的细胞状态进行输出,其表达式为
o t = σ ( W o ⋅ [ x t , h t − 1 ] + b o ) o_{t}=\sigma(W_{o}\cdot [x_{t},h_{t-1}]+b_{o}) ot=σ(Wo[xt,ht1]+bo)
h t = o t ⊙ t a n h ( C t ) h_{t}=o_{t} \odot tanh(C_{t}) ht=ottanh(Ct)
       输出门由上一时刻的输出和本时刻的输入,以及细胞状态两部分组成。首先使用sigmoid函数求得,来确定输出细胞状态的哪一部分。再通过tanh函数处理细胞状态,并与,输出这一时刻的结果
       最后再进行预测输出
y t ^ = σ ( V ⋅ h t ] + b y ) \hat{y_{t}}=\sigma(V\cdot h_{t}]+b_{y}) yt^=σ(Vht]+by)

LSTM的反向传播

       待填坑

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值