1 网络结构
2 实现代码
2.1 RNNCell
import torch
# Parameters for the RNNCell demo.
batch_size = 1
seq_len = 3
input_size = 4
hidden_size = 2

# Construction of RNNCell: a cell handles exactly ONE time step per call,
# so the loop over the sequence has to be written by hand below.
cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

# Input sequence laid out as (seq_len, batch_size, input_size); iterating over
# it walks dim 0, yielding one (batch_size, input_size) slice per time step.
dataset = torch.randn(seq_len, batch_size, input_size)
# Initial hidden state, shape (batch_size, hidden_size).
hidden = torch.zeros(batch_size, hidden_size)

# `x_t` instead of `input` so the builtin input() is not shadowed.
for idx, x_t in enumerate(dataset):
    print('=' * 10, idx, '=' * 10)
    print('Input size:', x_t.shape)
    # The returned hidden state is both this step's output and the state fed
    # into the next step.
    hidden = cell(x_t, hidden)
    print('Output size:', hidden.shape)
    print(hidden)
2.2 RNN
import torch
# Parameters for the RNN demo.
batch_size = 1
seq_len = 3
input_size = 4
hidden_size = 2
num_layers = 1

# Construction of RNN (not RNNCell): the module consumes the whole sequence in
# a single call and stacks `num_layers` recurrent layers.
cell = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size,
                    num_layers=num_layers)

# With the default batch_first=False, time is dim 0.
inputs = torch.randn(seq_len, batch_size, input_size)  # (seq_len, batch_size, input_size)
hidden = torch.zeros(num_layers, batch_size, hidden_size)  # (num_layers, batch_size, hidden_size)

# out: last layer's hidden state at every time step.
# hidden: final-time-step hidden state of every layer.
out, hidden = cell(inputs, hidden)
print('Output size:', out.shape)  # (seq_len, batch_size, hidden_size)
print('Output:', out)
print('Hidden size:', hidden.shape)  # (num_layers, batch_size, hidden_size)
print('Hidden:', hidden)
3 运行结果
3.1 RNNCell
3.2 RNN