前言
本节继续学习循环神经网络
- GRU
- LSTM
- 双向循环神经网络
1、门控循环单元(GRU)
- 当时间步数较大或者时间步较小时,循环神经⽹络的梯度较容易出现衰减或爆炸
- 裁剪梯度可以应对梯度爆炸,但无法解决梯度衰减
- 门控循环神经网络(gated recurrent neural network)的提出,正是为了更好地捕捉时间序列中时间步距离较大的依赖关系
- 有GRU和LSTM两种
GRU引入重置门和更新门
重置门和更新门
- 重置⻔有助于捕捉时间序列⾥短期的依赖关系
- 更新⻔有助于捕捉时间序列⾥⻓期的依赖关系
候选隐藏状态 - 当前时间步重置⻔的输出与上⼀时间步隐藏状态做按元素乘法(符号为⊙)
- 将按元素乘法的结果与当前时间步的输⼊连结
- 再通过含激活函数tanh的全连接层计算出候选隐藏状态
隐藏状态
- 上一时间步的隐藏状态和当前候选隐藏状态做个组合
实现
import d2lzh as d2l
from mxnet import nd
from mxnet.gluon import rnn
"""实现GRU"""
# 数据,周杰伦歌词
(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()
# 模型参数
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
ctx = d2l.try_gpu()
def get_params():
def _one(shape):
return nd.random.normal(scale=0.01, shape=shape, ctx=ctx)
def _three():
return (_one((num_inputs, num_hiddens)),
_one((num_hiddens, num_hiddens)),
nd.zeros(num_hiddens, ctx=ctx))
W_xz, W_hz, b_z = _three() # 更新门参数
W_xr, W_hr, b_r = _three() # 重置门参数
W_xh, W_hh, b_h = _three() # 候选隐藏状态参数
# 输出层参数
W_hq = _one((num_hiddens, num_outputs))
b_q = nd.zeros(num_outputs, ctx=ctx)
# 附上梯度
params = [W_xz