Using torch.nn.RNN
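import torch

# Input shape is (seq_len, batch, input_size), since batch_first defaults to False.
inputs = torch.randn(3, 4, 16)
rnn = torch.nn.RNN(input_size=16, hidden_size=6, num_layers=20)
# Initial hidden state has shape (num_layers, batch, hidden_size).
h0 = torch.randn(20, 4, 6)
# outputs holds the top layer's hidden state at every time step: (3, 4, 6).
outputs, _ = rnn(inputs, h0)
print(outputs)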
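
The snippet above discards the second return value. As a minimal sketch of what the call returns, the version below keeps it (named h_n here) and checks the shapes. The variable names seq_len, batch, h_n and last_step_matches, and the fixed seed, are assumptions added for illustration and are not part of the original post.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed only so repeated runs give the same values

# Same sizes as the example above, named for readability.
seq_len, batch, input_size, hidden_size, num_layers = 3, 4, 16, 6, 20

inputs = torch.randn(seq_len, batch, input_size)
rnn = torch.nn.RNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
h0 = torch.randn(num_layers, batch, hidden_size)

# outputs: top-layer hidden state at every time step.
# h_n: final hidden state of every layer.
outputs, h_n = rnn(inputs, h0)

print(outputs.shape)  # (seq_len, batch, hidden_size)
print(h_n.shape)      # (num_layers, batch, hidden_size)

# For a unidirectional RNN, the last time step of outputs is the
# final hidden state of the top layer.
last_step_matches = torch.allclose(outputs[-1], h_n[-1])
print(last_step_matches)
```

With batch_first=True the input and outputs would instead be laid out as (batch, seq_len, ...), while h0 and h_n keep the (num_layers, batch, hidden_size) layout.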