以下均为单向RNN。
0. RNN模型结构
网上教程的标准RNN结构如下图,其实是有输入层x、隐藏层h和输出层y三层结构的。
但是在Pytorch中定义的RNN,其实是没有y这个输出层的。例如下图中,Pytorch版本的两个输出,output=[h1, h2, h3, h4], hn = h4。如果想要得到输出层y,可以自行加一个全连接层。
1. 初始化RNN
rnn = nn.RNN(input_size, hidden_size, num_layers)
2. RNN的输入
- input:(seq_len, batch_size, input_size)
- h0:(num_layers, batch_s