本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/rnn_demo.py
这篇文章主要介绍了循环神经网络(Recurrent Neural Network),简称 RNN。
RNN 常用于处理不定长输入,常用于 NLP 以及时间序列的任务,这种数据一半具有前后关系。
RNN 网络结构如下:
上图的数据说明如下:
- x t x_{t} xt:时刻 t 的输入, s h a p e = ( 1 , 57 ) shape=(1,57) shape=(1,57),表示
(batch_size, feature_dim)
。57 表示词向量的长度。 - s t s_{t} st:时刻 t 的状态值, s h a p e = ( 1 , 128 ) shape=(1,128) shape=(1,128),表示
(batch_size, hidden_dim)
。这个状态值有两个作用:经过一个全连接层得到输出;输入到下一个时刻,影响下一个时刻的状态值。也称为hedden_state
,隐藏层状态信息,记录过往时刻的信息。第一个时刻的 s t s_{t} st 会初始化为全 0 的向量。 - o