RNN-torch.nn.RNN
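At its core, an RNN computes hidden = tanh(linear([input, hidden])): the current input and the previous hidden state are concatenated, passed through a single linear layer, and squashed with tanh. PyTorch writes the same update with two weight matrices: h_t = tanh(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh).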
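As a quick check of this formula, the sketch below uses nn.RNNCell (the single-step counterpart of nn.RNN, tanh by default) and shows that the two-matrix form and the "one linear layer over [input, hidden]" form both reproduce the cell's output. The shapes and variable names are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.RNNCell(input_size=10, hidden_size=20)  # one time step, tanh nonlinearity
x = torch.randn(3, 10)   # [batch_size, feature_size]
h = torch.randn(3, 20)   # [batch_size, hidden_size]

# Form 1: two weight matrices, as in the PyTorch docs
h_two = torch.tanh(x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh)

# Form 2: one linear layer applied to the concatenation [x, h]
W = torch.cat([cell.weight_ih, cell.weight_hh], dim=1)  # [20, 10 + 20]
b = cell.bias_ih + cell.bias_hh                          # the two biases just add up
h_cat = torch.tanh(torch.cat([x, h], dim=1) @ W.T + b)

h_ref = cell(x, h)  # PyTorch's own result
print(torch.allclose(h_two, h_ref, atol=1e-6), torch.allclose(h_cat, h_ref, atol=1e-6))
```

Placing W_ih and W_hh side by side is what makes "linear over [input, hidden]" an exact description rather than an analogy.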
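When calling PyTorch's nn.RNN, the one thing to keep straight is the tensor shapes. By default (batch_first=False), the input has shape [seq_length, batch_size, feature_size] and the initial hidden state has shape [num_layers, batch_size, hidden_size]; if the initial hidden state is omitted, it defaults to zeros. The call returns two results: the hidden state at every time step, taken from the last layer only (i.e. hidden[-1] when num_layers > 1), with shape [seq_length, batch_size, hidden_size]; and the final hidden state (last time step) of every layer, with shape [num_layers, batch_size, hidden_size].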
import torch
import torch.nn as nn
x = torch.randn(5, 3, 10)        # [seq_length=5, batch_size=3, feature_size=10]
hidden0 = torch.randn(2, 3, 20)  # [num_layers=2, batch_size=3, hidden_size=20]
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)
all_hidden, last_hidden = rnn(x, hidden0)
print(all_hidden.size())   # torch.Size([5, 3, 20]): [seq_length, batch_size, hidden_size]
print(last_hidden.size())  # torch.Size([2, 3, 20]): [num_layers, batch_size, hidden_size]
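To see exactly what the two return values contain, the sketch below unrolls the same 2-layer RNN by hand with the module's own parameters (weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}) and compares the result with nn.RNN. It assumes the default settings (tanh, bias=True, no dropout, unidirectional); the loop and variable names are illustrative, not PyTorch's internal implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, feat, hid, layers = 5, 3, 10, 20, 2
rnn = nn.RNN(input_size=feat, hidden_size=hid, num_layers=layers)
x = torch.randn(seq_len, batch, feat)
h0 = torch.randn(layers, batch, hid)
all_hidden, last_hidden = rnn(x, h0)

# Manual unroll: layer k reads the output sequence of layer k-1
layer_input = x
finals = []
with torch.no_grad():
    for k in range(layers):
        w_ih = getattr(rnn, f"weight_ih_l{k}")
        w_hh = getattr(rnn, f"weight_hh_l{k}")
        b_ih = getattr(rnn, f"bias_ih_l{k}")
        b_hh = getattr(rnn, f"bias_hh_l{k}")
        h = h0[k]
        outputs = []
        for t in range(seq_len):
            h = torch.tanh(layer_input[t] @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
            outputs.append(h)
        layer_input = torch.stack(outputs)  # becomes the input of the next layer
        finals.append(h)                    # last time step of this layer

manual_all = layer_input          # [seq_len, batch, hid]: top layer, every step
manual_last = torch.stack(finals) # [layers, batch, hid]: every layer, last step

print(torch.allclose(manual_all, all_hidden, atol=1e-5))
print(torch.allclose(manual_last, last_hidden, atol=1e-5))
print(torch.allclose(all_hidden[-1], last_hidden[-1]))  # the two results overlap here

# Omitting the initial hidden state is the same as passing zeros
out_default, _ = rnn(x)
out_zero, _ = rnn(x, torch.zeros(layers, batch, hid))
print(torch.allclose(out_default, out_zero))
```

The overlap check makes the relationship concrete: the last time step of all_hidden is the top layer's entry in last_hidden, while the lower layers' final states appear only in last_hidden.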