import torch
from torch import nn
import time
# Vanilla RNN demo: push one random batch through a single-layer,
# unidirectional nn.RNN and print the shapes of what it returns.
model_rnn = nn.RNN(input_size=4, hidden_size=128, num_layers=1, bias=True, bidirectional=False)
# batch_first is left at its default, so the input is (seq_len, batch, input_size).
x = torch.randn(31, 8, 4)
output, hn = model_rnn(x)
# `output` carries the hidden state for every time step; `hn` is the final
# hidden state, with a leading num_layers * num_directions axis.
print("output:", output.shape)
print("hn:", hn.shape)
# LSTM demo: report the trainable-parameter count of a single-layer,
# unidirectional nn.LSTM and time one forward pass over a random input of
# shape (seq_len=365, batch=8, input_size=4).
model_lstm = nn.LSTM(
    input_size=4,
    hidden_size=256,
    num_layers=1,
    bidirectional=False,
)
# A generator expression sums the sizes without building a throwaway list.
print("lstm param number:", sum(p.numel() for p in model_lstm.parameters() if p.requires_grad))
x = torch.randn(size=(365, 8, 4))
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump when the system
# time is adjusted, which would corrupt the measurement.
start_time = time.perf_counter()
output, (hn, cn) = model_lstm(x)
end_time = time.perf_counter()
time_cousuming = round((end_time - start_time) * 1000)  # elapsed milliseconds
print("lstm time cousuming:", time_cousuming)
# Unlike RNN/GRU, the LSTM returns two final states: hidden (hn) and cell (cn).
print("lstm output shape:", output.size(), hn.size(), cn.size())
# GRU demo: report the trainable-parameter count of a single-layer,
# unidirectional nn.GRU and time one forward pass over a random input of
# shape (seq_len=365, batch=8, input_size=4).
model_GRU = nn.GRU(
    input_size=4,
    hidden_size=256,
    num_layers=1,
    bidirectional=False,
)
# A generator expression sums the sizes without building a throwaway list.
print("GRU param number:", sum(p.numel() for p in model_GRU.parameters() if p.requires_grad))
x = torch.randn(size=(365, 8, 4))
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump when the system
# time is adjusted, which would corrupt the measurement.
start_time = time.perf_counter()
output, hn = model_GRU(x)
end_time = time.perf_counter()
time_cousuming = round((end_time - start_time) * 1000)  # elapsed milliseconds
print("GRU time cousuming:", time_cousuming)
print("GRU output shape:", output.size(), hn.size())
# Bidirectional RNN demo: time one forward pass of a 2-layer bidirectional
# nn.GRU over a random input of shape (seq_len=365, batch=8, input_size=4).
model_bi_GRU = nn.GRU(
    input_size=4,
    hidden_size=256,
    num_layers=2,
    bidirectional=True,
)
x = torch.randn(size=(365, 8, 4))
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump when the system
# time is adjusted, which would corrupt the measurement.
start_time = time.perf_counter()
output, hn = model_bi_GRU(x)
end_time = time.perf_counter()
time_cousuming = round((end_time - start_time) * 1000)  # elapsed milliseconds
print("time consuming:", time_cousuming)
# The forward and backward directions are concatenated on the last axis of
# `output`, so its feature size is 2 * hidden_size.
print("bi GRU output shape:", output.size())
print("bi GRU hn shape:", hn.size())  # torch.Size([4, 8, 256]) [num_layers * num_directions, batch, hidden_size]
# Output (from one run; the timings are in milliseconds and vary by machine):
# output: torch.Size([31, 8, 128])
# hn: torch.Size([1, 8, 128])
# lstm param number: 268288
# lstm time cousuming: 65
# lstm output shape: torch.Size([365, 8, 256]) torch.Size([1, 8, 256]) torch.Size([1, 8, 256])
# GRU param number: 201216
# GRU time cousuming: 60
# GRU output shape: torch.Size([365, 8, 256]) torch.Size([1, 8, 256])
# time consuming: 318
# bi GRU output shape: torch.Size([365, 8, 512])
# bi GRU hn shape: torch.Size([4, 8, 256])