RNN变体之LSTM和GRU原理
- 传统循环神经网络具有长时依赖性
- 为了解决这个问题出现了循环神经网络的变式,这两种变式都能够很好地解决长时依赖的问题:
- LSTM
- GRU
1.LSTM算法
1.1 基本概念
1.1.1 LSTM
-
LSTM( Long Short Term Memory Networks) 称为长短时记忆网络
-
是传统RNN的变体,比RNN有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸问题。
-
它解决的是短时记忆的问题,只不过这种短时记忆比较长,能在一定程度上解决长时依赖的问题
-
LSTM由三个门来控制,分别是输入门、遗忘门和输出门
- 输入门控制着网络的输入
- 遗忘门控制着记忆单元
- 输出门控制着网络的输出
-
遗忘门是最重要的,遗忘门的作用是决定之前的哪些记忆将被保留,哪些记忆将被去掉
-
由于遗忘门的作用,使得 LSTM 具有了长时记忆的功能,对于给定的任务,遗忘门能够自己学习保留多少以前的记忆
-
LSTM抽象结构
-
LSTM内部结构
1.1.2 Bi-LSTM
- Bi-LSTM是双向LSTM,将LSTM应用2次且不同方向,将两次得到的结果进行拼接作为做种输出
- Bi-LSTM没有改变LSTM内部结构
1.2. LSTM结构参数理解
参数 | 含义 |
---|---|
h t − 1 h_{t-1} ht−1 | t-1时刻网络输出 |
h t h_{t} ht | t时刻网络输出, h t h_{t} ht 取决于当前时刻 t 的记忆状 C t C_{t} Ct和t时刻的输入 x t x_{t} xt、t- 1 时刻的输出 h t − 1 h_{t-1} ht−1 |
C t − 1 C_{t-1} Ct−1 | 上 t-1 时刻网络中的记忆单元 |
C t C_{t} Ct | 下一 时刻的记忆状态 |
C t ~ \widetilde{C_{t}} Ct | 当前时刻的记忆状态 |
f t f_{t} ft | t时刻网络的输入和 t - 1网络的输出 ,得到 t-1 时刻下的衰减系数 |
i t i_{t} it | t时刻网络的输入和 t - 1网络的输出 ,得到 t 时刻下的衰减系数 |
o t o_{t} ot | 输出门衰减系数 |
σ \sigma σ | sigmoid激活函数 |
⊗ \otimes ⊗ | 数据相乘 |
⊕ \oplus ⊕ | 数据拼接相加 |
1.3 LSTM结构分析
-
遗忘门
- 衰减系数f(t)值作用于上一时刻细胞状态上,代表遗忘过去多少信息,
- 衰减系数由x(t),h(t-1)计算得到,所以衰减系数公式代表着当前时刻的输入x(t)和上一时刻的h(t-1)来决定遗忘多少上一时刻的细胞状态的信息
-
输入门
- 当前时刻的数据信息。
- 代表着输入信息需要舍弃多少,得到当前的细胞状态
-
细胞状态更新
- 将遗忘门与输入门数据进行式更新细胞状态
- 将遗忘门与输入门数据进行式更新细胞状态
-
输出门
- 计算输出衰减系数,更新后的细胞状态进行tanh激活,最终得到h(t),得到隐含状态h(t)
- 计算输出衰减系数,更新后的细胞状态进行tanh激活,最终得到h(t),得到隐含状态h(t)
1.4. LSTM公式理解
- 相关公式
f t = σ ( W f [ h t − 1 , x t ] + b f ) i t = σ ( W i [ h t − 1 , x t ] + b i ) C t ~ = t a n h ( W c [ h t − 1 , x t ] + b C ) C t = f t ∗ C t − 1 + i t ∗ C t ~ o t = σ ( W o [ h t − 1 , x t ] + b o ) h t = o t ∗ t a n h ( C t ) \begin{aligned} f_t&= \sigma(W_f[h_{t-1},x_t]+b_f)\\ i_t&= \sigma(W_i[h_{t-1},x_t]+b_i)\\ \widetilde{C_{t}}&= tanh(W_c[h_{t-1},x_t]+b_C)\\ C_t&=f_t*C_{t-1}+i_t*\widetilde{C_{t}}\\ o_t&= \sigma(W_o[h_{t-1},x_t]+b_o)\\ h_t&=o_{t}*tanh(C_t) \end{aligned} ftitCt Ctotht=σ(Wf[ht−1,xt]+bf)=σ(Wi[ht−1,xt]+bi)=tanh(Wc[ht−1,xt]+bC)=ft∗Ct−1+it∗Ct =σ(Wo[ht−1,xt]+bo)=ot∗tanh(Ct) - 参数描述
-
C
t
−
1
C_{t-1}
Ct−1作为上
t-1
时刻网络中的记忆单元,传入t
时刻的网络之后,第一步操作是决定它的遗忘程度。- 衰减系数:将
t
时刻前面的记忆状态乘上0-1
的系数进行衰减,接着加上t
时刻学到的记忆作为更新之后的记忆传出网络,作为t+l
时刻网络的记忆单元 t-1
时刻网络记忆的衰减系数是通过t
时刻网络的输入和t-1
网络的输出来确定的,t
时刻网络记忆也是根据t
时刻网络的输入和t-1
时刻网络的输出得到的,即 f t 与 i t f_t与i_t ft与it的公式一样,但参数不一样
- 衰减系数:将
-
C
t
−
1
C_{t-1}
Ct−1作为上
1.5 LSTM优缺点
- 优点
- 能够有效缓解长序列问题中可能出现的梯度消失或爆炸
- 缺点
- 内部结构相对较复杂,训练效率比传统RNN低很多
- 不能完全解决梯度消失或爆炸问题
2.GRU
2.1 基本概念
2.1.1 GRU算法
- GRU是Gated Recurrent Unit的缩写,也称门控循环单元结构.
- 是传统RNN的变体,比RNN有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸问题。
- GRU比LSTM结构和计算更简单
- GRU与 LSTM 最大的不同在于 GRU 将遗忘门和输入门合成了一个“更新门”。
- 网络不再额外给出记忆状态 C t C_t Ct ,而是将输出结果 h t h_t ht 作为记忆状态不断向后循环传递,网络的输入和输出都变得特别简单
2.1.2 Bi-GRU
- Bi-GRU是双向GRU,将GRU应用2次且不同方向,将两次得到的结果进行拼接作为做种输出
- Bi-GRU没有改变GRU内部结构
2.2. GRU结构参数理解
参数 | 含义 |
---|---|
h t − 1 h_{t-1} ht−1 | t-1时刻网络输出 |
h t h_{t} ht | t时刻网络输出, h t h_{t} ht 取决于当前时刻 t 的记忆状 C t C_{t} Ct和t时刻的输入 x t x_{t} xt、t- 1 时刻的输出 h t − 1 h_{t-1} ht−1 |
x t x_{t} xt | t时刻输入 |
h ~ t \widetilde{h}_t h t | 当前时刻的记忆状态 |
r t r_{t} rt | t时刻网络的输入和 t - 1网络的输出 ,得到 t-1 时刻下的衰减系数 |
z t z_{t} zt | t时刻网络的输入和 t - 1网络的输出 ,得到 t 时刻下的衰减系数 |
σ \sigma σ | sigmoid激活函数 |
⊗ \otimes ⊗ | 数据相乘 |
⊕ \oplus ⊕ | 数据拼接相加 |
- 相关公式
- 输入门与遗忘门
z t = σ ( W z [ h t − 1 , x t ] ) r t = σ ( W r [ h t − 1 , x t ] ) h ~ t = t a n h ( W ∗ [ r t ∗ h t − 1 , x t ] ) \begin{aligned} z_t&= \sigma(W_z[h_{t-1},x_t])\\ r_t&= \sigma(W_r[h_{t-1},x_t])\\ \widetilde{h}_t&= tanh(W*[r_{t} * h_{t-1},x_t]) \end{aligned} ztrth t=σ(Wz[ht−1,xt])=σ(Wr[ht−1,xt])=tanh(W∗[rt∗ht−1,xt]) - 输出门
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t \begin{aligned} h_t&=(1-z_t)*h_{t-1}+z_t*\widetilde{h}_t \end{aligned} ht=(1−zt)∗ht−1+zt∗h t
- 输入门与遗忘门
- 由输出门公式发现,最终结果与
z
t
z_t
zt有较大关系
- 当 z t z_t zt接近于0,说明模型更关注前面信息,不关注当前的信息
- 当 z t z_t zt接近于0.5,说明模型关注当前信息,也关注前面的信息
- 当 z t z_t zt接近于1,说明模型更关注当前信息,不关注前面的信息
2.3.GRU优缺点
- 优点
- 有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸问题。
- 效果都优于传统RNN且计算复杂度相比LSTM要小
- 缺点
- 不能完全解决梯度消失或爆炸问题
- 不可并行计算,在数据量和模型体量的增大,是未来发展的关键瓶颈