LSTM输入结构

本文详细介绍了LSTM的输入、输出格式以及在PyTorch中的定义。内容包括LSTM的数据立方体结构,LSTM模型参数解释,如input_size、hidden_size、num_layers等,并讨论了batch_first参数和双向LSTM的特点。此外,还阐述了如何准备LSTM的数据输入,LSTM的初始状态h0和c0,以及输出的tuple结构。最后,探讨了LSTM与其他网络(如全连接层)组合的应用场景。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

https://zhuanlan.zhihu.com/p/139617364

 

为了更好理解LSTM结构,必须理解LSTM的数据输入情况。仿照3通道图像的样子,在加上时间轴后的多样本的多特征的不同时刻的数据立方体如下图所示:

三维数据立方体

 

右边的图是我们常见模型的输入,比如XGBOOST,lightGBM,决策树等模型,输入的数据格式都是这种(N*F)的矩阵,而左边是加上时间轴后的数据立方体,也就是时间轴上的切片,它的维度是(N*T*F),第一维度是样本数,第二维度是时间,第三维度是特征数,如下图所示:

这样的数据立方体很多,比如天气预报数据,把样本理解成城市,时间轴是日期,特征是天气相关的降雨风速PM2.5等,这个数据立方体就很好理解了。在NLP里面,一句话会被embedding成一个矩阵,词与词的顺序是时间轴T,索引多个句子的embedding三维矩阵如下图所示:

 

pytorch中定义的LSTM模型

pytorch中定义的LSTM模型的参数如下

class torch.nn.LSTM(*args, **kwargs)
参数有:
    input_size:x的特征维度
    hidden_size:隐藏层的特征维度
    num_layers:lstm隐层的层数,默认为1
    bias&
### LSTM 输入格式 对于LSTM模型,在TensorFlow和PyTorch中的输入格式存在差异。在PyTorch中,当`batch_first=False`时,默认情况下LSTM层期望接收的张量形状为`(seq_len, batch_size, input_size)`[^4]。 而在TensorFlow中,由于Keras API已经被集成进来,创建DNN或者RNN(包括LSTM)变得更为简便。不过关于具体的输入格式,通常也是遵循类似的模式,即三维张量形式,但默认顺序可能有所不同[(seq_length, batch_size, feature)],这取决于具体版本以及配置选项[^1]。 为了确保兼容性和易用性,特别是在使用PyTorch定义LSTM网络结构的时候,可以通过设置参数`batch_first=True`来调整输入输出的第一个维度代表批次大小(batch),从而让其更符合直觉上的理解方式——此时输入应该被组织成`(batch_size, seq_len, input_size)`的形式。 #### 设置LSTM模型的输入 下面给出一段简单的代码片段用于展示如何正确地准备并传递数据给基于PyTorch实现的LSTM模块: ```python import torch from torch import nn class MyLSTM(nn.Module): def __init__(self, input_dim=10, hidden_dim=20, n_layers=2, batch_first=True): super(MyLSTM, self).__init__() self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=batch_first) def forward(self, x): out, (hn, cn) = self.lstm(x) return out[-1] # 假设我们有一个批量的数据样本数量为5,时间步长为7,特征数为10 data = torch.randn(5, 7, 10) model = MyLSTM() output = model(data) print(output.shape) ``` 这段代码展示了如何通过指定`batch_first=True`使输入张量的第一维表示批次数目,并且演示了一个基本的前向传播过程。
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值