TF2 RNN篇之GRU原理
LSTM 具有更长的记忆能力,在大部分序列任务上面都取得了比基础的RNN 模型更好的性能表现,更重要的是,LSTM 不容易出现梯度弥散现象。但是LSTM 结构相对较复杂,计算代价较高,模型参数量较大。因此,科学家们尝试简化LSTM 内部的计算流程,特别是减少门控数量。研究发现,遗忘门是LSTM 中最重要的门控 [2],甚至发现只有遗忘门的简化版网络在多个基准数据集上面优于标准LSTM 网络。在众多的简化版LSTM中,门控循环网络(Gated Recurrent Unit,简称GRU)是应用最广泛的RNN 变种之一。GRU把内部状态向量和输出向量合并,统一为状态向量 ,门控数量也减少到2 个:复位门(Reset Gate)和更新门(Update Gate),如图所示。
下面我们来分别介绍复位门和更新门的原理与功能。
复位门
复位门用于控制上一个时间戳的状态 h 𝑡−1进入GRU 的量。门控向量𝒈𝑟由当前时间戳
输入𝒙𝑡和上一时间戳状态 h 𝑡−1变换得到,关系如下:
其中𝑾𝑟和𝒃𝑟为复位门的参数,由反向传播算法自动优化,𝜎为激活函数,一般使用
Sigmoid 函数。门控向量𝒈𝑟只控制状态 h 𝑡−1,而不会控制输入𝒙𝑡:
当𝒈𝑟 = 0时,新输入全部来自于输入𝒙𝑡,不接受 h 𝑡−1,此时相当于复位 h 𝑡−1。当𝒈𝑟 = 1时, h 𝑡−1和输入𝒙𝑡共同产生新输入,如图 所示
更新门
更新门用控制上一时间戳状态 𝑡−1和新输入
h
^
t
−
1
\hat{h}_{t-1}
h^t−1对新状态向量 𝑡的影响程度。更新门控向量𝒈𝑧由
得到,其中𝑊𝑧和𝒃𝑧为更新门的参数,由反向传播算法自动优化,𝜎为激活函数,一般使用Sigmoid 函数。𝒈𝑧用与控制新输入
h
^
t
−
1
\hat{h}_{t-1}
h^t−1信号,1 − 𝒈𝑧用于控制状态 h 𝑡−1信号
h
t
=
(
1
−
g
z
)
h
t
−
1
+
g
t
h
^
t
−
1
h_t = (1-g_z)h_{t-1} + g_t\hat{h}_{t-1}
ht=(1−gz)ht−1+gth^t−1
可以看到,
h
^
t
−
1
\hat{h}_{t-1}
h^t−1和
h
t
−
1
h_{t-1}
ht−1对 h 𝑡的更新量处于相互竞争、此消彼长的状态。当更新门𝒈𝑧 = 0时,h 𝑡全部来自上一时间戳状态 h 𝑡−1;当更新门𝒈𝑧 = 1时, 𝑡全部来自新输入
h
^
t
−
1
\hat{h}_{t-1}
h^t−1。
GRU 实战
GRU的使用方法和LSTM类似只需要把SimpleRNN的代码稍作修改即可,SimpleRNN的代码在博客最上方的导览中有
GRUCell
和SimpleRNN一样,Tensorflow提供了两个网络层的表达GRU方式,一个是GRUCell一个是GRU
self.rnn_cell0 = laGRUyers.SimpleRNNCell(units, dropout=0.5)
self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)
==>
self.rnn_cell0 = layers.GRUCell(units,dropout=0.5)
self.rnn_cell1 = layers.LSTMCell(units,dropout=0.5)
只要改这俩行代码即可
GRU层
如果是GRU层就更简单了
self.rnn = keras.Sequential([
layers.GRU(units,dropout=0.5,return_sequences=True,unroll=True),
layers.GRU(units,dropout=0.5,unroll=True)
])
构建时候把SimpleRNN改成GRU就ok了
参考书籍: TensorFlow 深度学习 — 龙龙老师