https://zhuanlan.zhihu.com/p/139617364
为了更好理解LSTM结构,必须理解LSTM的数据输入情况。仿照3通道图像的样子,在加上时间轴后的多样本的多特征的不同时刻的数据立方体如下图所示:
三维数据立方体
右边的图是我们常见模型的输入,比如XGBOOST,lightGBM,决策树等模型,输入的数据格式都是这种(N*F)的矩阵,而左边是加上时间轴后的数据立方体,也就是时间轴上的切片,它的维度是(N*T*F),第一维度是样本数,第二维度是时间,第三维度是特征数,如下图所示:
这样的数据立方体很多,比如天气预报数据,把样本理解成城市,时间轴是日期,特征是天气相关的降雨风速PM2.5等,这个数据立方体就很好理解了。在NLP里面,一句话会被embedding成一个矩阵,词与词的顺序是时间轴T,索引多个句子的embedding三维矩阵如下图所示:
pytorch中定义的LSTM模型
pytorch中定义的LSTM模型的参数如下
class torch.nn.LSTM(*args, **kwargs)
参数有:
input_size:x的特征维度
hidden_size:隐藏层的特征维度
num_layers:lstm隐层的层数,默认为1
bias&