import torch
import torch.nn as nn
class BiLSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0):
super(BiLSTM, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
# Define the bidirectional LSTM layer
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
batch_first=True, dropout=dropout, bidirectional=True)
# Define the fully connected layer
self.fc = nn.Linear(hidden_size * 2, output_size)
def forward(self, x):
# Initialize hidden state and cell state
h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
# Forward propagate LSTM
out, _ = self.lstm(x, (h0, c0))
# Decode the hidden state of the last time step
out = self.fc(out[:, -1, :])
return out
# Example usage
input_size = 10
hidden_size = 20
output_size = 5
seq_length = 7
batch_size = 3
model = BiLSTM(input_size, hidden_size, output_size)
# Create a sample input tensor
sample_input = torch.randn(batch_size, seq_length, input_size)
# Get the model output
output = model(sample_input)
print(output)
代码解释
-
初始化函数
__init__
:input_size
是输入特征的维度。hidden_size
是LSTM隐藏层的维度。output_size
是输出的维度。num_layers
是LSTM层数。dropout
是dropout的概率。
-
LSTM层:
self.lstm
定义了一个双向LSTM层(bidirectional=True
)。
-
全连接层:
self.fc
是一个全连接层,用于将LSTM的输出映射到所需的输出维度。
-
前向传播函数
forward
:- 初始化LSTM的隐藏状态和细胞状态。
- 通过LSTM层进行前向传播。
- 使用全连接层处理LSTM的最后一个时间步的输出。
-
使用示例:
- 定义了模型的输入和输出维度。
- 创建了一个BiLSTM实例。
- 生成了一个随机输入张量并获取模型输出。