Basic RNN、LSTM的前向传播和反向传播详细解析
Basic RNN、LSTM由于它们独特的架构,被大量应用在自然语言处理和序列模型的任务上。通过它们自身特殊的结构,可以记住之前的输入中的部分内容和信息,并对之后的输出产生影响。
- 本文主要针对 :对RNN和LSTM有一定基础了解,但是对公式推导还不是完全掌握的童鞋(尤其是lstm的反向传播部分),欢迎各位批评指正~
- 由于markdown编辑公式太麻烦了,所以公式也都是本地编辑之后的截图,有不正确的地方欢迎指正
Basic RNN架构简介
整体架构
模型的整体结构如下图所示,输入的是序列x、输出y,长度为Tx。当然还有针对输入输出不相等的RNN结构,这里只是为了详解RNN的公式推导,特别是反向传播的推导,所以不再赘述。
上图其实是沿时间轴展开的RNN模型,其实图中所有的RNN-cell都共用一套参数,每个cell输入当前时间点的输入x < t >和前一个cell输出的a < t-1 >,得到当前cell的输出a < t >和y < t >。
BasicRNN 前向传播
- 现在我们单独对每个cell进行公式推导,最终整个模型的公式其实就是单个cell的循环调用。
- 下图是单个cell的具体结构图,以及前向传播的公式,非常的简洁明了
Figure 2: Basic RNN cell
BasicRNN 反向传播
针对前面介绍的每个cell前向传播图和公式,我们能很快的写出针对每个cell的反向传播公式:
由前向传播的单个cell图,根据梯度反向传播易知。当前cell的 ∂J∂a<t> ∂ J ∂ a < t > 由两部分构成:
- 当前cell的输出 ŷ y ^ < t >与真实标签代入损失函数,通过损失函数对a< t >求导得到的梯度da< t >1
- 输入到下一个cell的a< t >传回的梯度da< t >2
公式推导前我们还需要知道 ∂tanh(x)∂x=1−(tanh(x))2 ∂ t a n h ( x ) ∂ x = 1 − ( t a n h