第R3周:LSTM-火灾温度预测:3. nn.LSTM() 函数详解

nn.LSTM 是 PyTorch 中用于创建长短期记忆(Long Short-Term Memory,LSTM)模型的类。LSTM 是一种循环神经网络(Recurrent Neural Network,RNN)的变体,用于处理序列数据,能够有效地捕捉长期依赖关系。

语法

torch.nn.LSTM(input_size, hidden_size, num_layers=1, 
              bias=True, batch_first=False, 
              dropout=0, bidirectional=False)

● input_size: 输入特征的维度。
● hidden_size: 隐藏状态的维度,也是输出特征的维度。
● num_layers(可选参数): LSTM 层的数量,默认为 1。
● bias(可选参数): 是否使用偏置,默认为 True。
● batch_first(可选参数): 如果为 True,则输入和输出张量的形状为 (batch_size, seq_len, feature_size),默认为 False,张量的形状为(seq_len, batch_size, feature_dim)。
● dropout(可选参数): 如果非零,将在 LSTM 层的输出上应用 dropout,防止过拟合。默认为 0。
● bidirectional(可选参数): 如果为 True,则使用双向 LSTM,输出维度将翻倍。默认为 False。

示例

import torch
import torch.nn as nn

# 定义一个单向 LSTM 模型
input_size  = 10
hidden_size = 20
num_layers  = 2
batch_size  = 3
seq_len     = 5

lstm = nn.LSTM(input_size, hidden_size, num_layers)

# 构造一个输入张量
input_tensor = torch.randn(seq_len, batch_size, input_size)

# 初始化隐藏状态和细胞状态
h0 = torch.randn(num_layers, batch_size, hidden_size)
c0 = torch.randn(num_layers, batch_size, hidden_size)

# 将输入传递给 LSTM 模型
output, (hn, cn) = lstm(input_tensor, (h0, c0))

print("Output shape:", output.shape)    # 输出特征的形状
print("Hidden state shape:", hn.shape)  # 最后一个时间步的隐藏状态的形状
print("Cell state shape:", cn.shape)    # 最后一个时间步的细胞状态的形状

代码输出

Output shape: torch.Size([5, 3, 20])
Hidden state shape: torch.Size([2, 3, 20])
Cell state shape: torch.Size([2, 3, 20])

注意事项

● input_size 指定了输入数据的特征维度,hidden_size 指定了 LSTM 层的隐藏状态维度,num_layers 指定了 LSTM 的层数。
● LSTM 的输入张量的形状通常是 (seq_len, batch_size, input_size),但如果设置了 batch_first=True,则形状为 (batch_size, seq_len, input_size)。
● LSTM 的输出包括输出张量和最后一个时间步的隐藏状态和细胞状态。
● 可以通过 bidirectional=True 参数创建双向 LSTM,它会将输入序列分别从前向和后向传播,并将两个方向的隐藏状态拼接在一起作为输出。
● 在使用 LSTM 时,通常需要注意输入数据的序列长度,以及是否需要对输入数据进行填充或截断,以保证输入数据的长度是一致的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值