import torch
import torch.nn as nn
# Single-layer bidirectional LSTM.
# With batch_first=True only `out` is (batch, seq_len, num_directions * hidden_size);
# h and c keep the layout (num_layers * num_directions, batch, hidden_size).
lstm = nn.LSTM(input_size=100, hidden_size=200, bidirectional=True, batch_first=True)
a = torch.randn(32, 512, 100)  # (batch, seq_len, input_size) because batch_first=True
out, (h, c) = lstm(a)
print(out.shape)  # expected: torch.Size([32, 512, 400])
print(h.shape)  # expected: torch.Size([2, 32, 200])
# The forward direction ends at the LAST time step, so the first hidden_size
# features of out[:, -1] are its final state h[0] (checked for every sample).
print(torch.allclose(out[:, -1, :lstm.hidden_size], h[0]))
# The backward direction ends at the FIRST time step, so the last hidden_size
# features of out[:, 0] are its final state h[1].
print(torch.allclose(out[:, 0, lstm.hidden_size:], h[1]))
# Multi-layer (3-layer) bidirectional LSTM.
# h has shape (num_layers * num_directions, batch, hidden_size); it can be viewed
# as (num_layers, num_directions, batch, hidden_size), so the top layer's
# forward/backward final states are h[-2] and h[-1].
lstm = nn.LSTM(input_size=100, hidden_size=200, num_layers=3, bidirectional=True, batch_first=True)
a = torch.randn(32, 512, 100)  # (batch, seq_len, input_size) because batch_first=True
out, (h, c) = lstm(a)
print(out.shape)  # expected: torch.Size([32, 512, 400])
print(h.shape)  # expected: torch.Size([6, 32, 200])
# `out` only contains the top layer's outputs, so it lines up with the top
# layer's final states: forward ends at the last step, backward at the first.
print(torch.allclose(out[:, -1, :lstm.hidden_size], h[-2]))
print(torch.allclose(out[:, 0, lstm.hidden_size:], h[-1]))