上篇文章我们已经学习了循环神经网络的原理,并指出RNN存在严重的梯度爆炸和梯度消失问题,因此很难处理长序列的数据。本篇文章,我们将学习长短期记忆网络(LSTM,Long Short Term Memory),看LSTM解决RNN所带来的梯度消失和梯度爆炸问题。
1.从RNN到LSTM
RNN模型具有如下所示的结构,其中每个索引位置t都有一个隐藏状态 h ( t ) h^{(t)} h(t)。
如果省略每层的 o ( t ) , L ( t ) , y ( t ) o^{(t)},L^{(t)},y^{(t)} o(t),L(t),y(t),则RNN模型可以简化到如下所示的结构。其中隐藏状态的 h ( t ) h^{(t)} h(t)由 x ( t ) x^{(t)} x(t)和 h ( t − 1 ) h^{(t-1)} h(t−1)得到,得到 h ( t ) h^{(t)} h(t)后可用于计算当前层的模型损失和下一层的 h ( t + 1 ) h^{(t+1)} h(t+1)。
为解决梯度消失的问题,大牛们针对RNN序列索引位置t的隐藏结构作出相应改进,进而提出LSTM模型。其中LSTM模型有多种形式,下面我们以最常见的LSTM模型为例进行讲解。
2.LSTM模型结构
LSTM模型除了和RNN模型具有相同的隐藏状态 h ( t ) h^{(t)} h(t)外,还增加了新的隐藏状态 C ( t ) C^{(t)} C(t),如下图中横线所示。新增加的隐藏状态称为细胞状态(Cell State),记为 C ( t ) C^{(t)} C(t)。
除了细胞状态外,LSTM中还多了很多奇怪的结构,称之为门控结构(Gate)。针对每个序列索引位置t,门控结构一般包含遗忘门、输入门和输出门,下面来看看门控结构和细胞状态的结构。
2.1 LSTM之遗忘门
**遗忘门(forget gate)**是以一定的概率控制是否遗忘上一层的隐藏细胞状态,遗忘门的结构如下所示。
输入是上一序列的隐藏状态 h ( t − 1 ) h^{(t-1)} h(t−1)和本序列数据 x ( t ) x^{(t)} x(t),通过一个激活函数(一般是sigmoid),得到遗忘门的输出 f ( t ) f^{(t)} f(t)。由于sigmoid的输出 f ( t ) f^{(t)} f(t)在[0,1]之间,因此这里的 f ( t ) f^{(t)} f(t)代表遗忘上一层隐藏细胞状态的概率,数学表达式如下所示。其中 W f , U f , b f W_{f},U_{f},b_{f} Wf,Uf,bf为线性关系的系数和偏倚, σ \sigma σ为sigmoid激活函数。
f ( t ) = σ ( W f h ( t − 1 ) + U f x ( t ) + b f ) f^{(t)} = \sigma(W_{f}h^{(t-1)} + U_{f}x^{(t)} + b_{f}) f(t)=σ(Wfh(t−1)+Ufx(t)+bf)
2.2 LSTM之输入门
**输入门(input gate)**负责处理当前序列位置的输入,输入门的结构如下所示。
输入门由两部分组成,第一部分使用sigmoid激活函数,输出为 i ( t ) i^{(t)} i(t),第二部分使用tanh激活函数,输出为 a ( t ) a^{(t)} a(t),两者的结果后面会用于相乘后更新细胞状态。 i ( t ) i^{(t)} i(t)和 a ( t ) a^{(t)} a(t)数学表达式如下所示,其中 W i , U i , b i , W a , U a , b a W_{i},U_{i},b_{i},W_{a},U_{a},b_{a} Wi,Ui,bi,Wa,Ua,ba为线性关系的系数和偏倚, σ \sigma σ为sigmoid激活函数。
i ( t ) = σ ( W i h ( t − 1 ) + U i x ( t ) + b i ) i^{(t)} = \sigma(W_{i}h^{(t-1)} + U_{i}x^{(t)} + b_i) i(t)=σ(Wih(t−1)+Uix(t)+bi)
a ( t ) = tanh ( W a h ( t − 1 ) + U a x ( t ) + b a ) a^{(t)} = \tanh(W_{a}h^{(t-1)} + U_{a}x^{(t)} + b_a) a(t)=tanh(Wah(t−1)+Uax(t)+ba)
2.3 LSTM之细胞状态更新
研究LSTM输出门之前,我们先看一下LSTM细胞状态的更新,其中遗忘门和输入门的结果都作用于细胞状态 C ( t ) C^{(t)} C(t)。
细胞状态 C ( t ) C^{(t)} C(t)由两部分组成,第一部分是 C ( t − 1 ) C^{(t-1)} C(t−1)和遗忘门输出 f ( x ) f^{(x)} f(x)的乘积,第二部分是输入门的 i ( t ) i^{(t)} i(t)和 a ( t ) a^{(t)} a