BILSTM代码

import torch
import torch.nn as nn

class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Define the bidirectional LSTM layer
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, 
                            batch_first=True, dropout=dropout, bidirectional=True)
        
        # Define the fully connected layer
        self.fc = nn.Linear(hidden_size * 2, output_size)
    
    def forward(self, x):
        # Initialize hidden state and cell state
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))
        
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

# Example usage
input_size = 10
hidden_size = 20
output_size = 5
seq_length = 7
batch_size = 3

model = BiLSTM(input_size, hidden_size, output_size)

# Create a sample input tensor
sample_input = torch.randn(batch_size, seq_length, input_size)

# Get the model output
output = model(sample_input)
print(output)

代码解释

  1. 初始化函数 __init__

    • input_size 是输入特征的维度。
    • hidden_size 是LSTM隐藏层的维度。
    • output_size 是输出的维度。
    • num_layers 是LSTM层数。
    • dropout 是dropout的概率。
  2. LSTM层

    • self.lstm 定义了一个双向LSTM层(bidirectional=True)。
  3. 全连接层

    • self.fc 是一个全连接层,用于将LSTM的输出映射到所需的输出维度。
  4. 前向传播函数 forward

    • 初始化LSTM的隐藏状态和细胞状态。
    • 通过LSTM层进行前向传播。
    • 使用全连接层处理LSTM的最后一个时间步的输出。
  5. 使用示例

    • 定义了模型的输入和输出维度。
    • 创建了一个BiLSTM实例。
    • 生成了一个随机输入张量并获取模型输出。
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值