本文记录LSTNet的原理、代码解读及相关问题。如有错误还请大家指出。
作者源码见:GitHub - laiguokun/LSTNet
多步预测见:基于Pytorch的LSTNet多步时间预测(附网络代码)-CSDN博客
准备工作:
运行需要在目录中创建data文件夹及save文件夹,并将数据集放入data文件夹中
然后在main.py中邮件设置运行环境变量
将 --data data/solar_AL.txt --save save/solar_AL.pt --hidSkip 10 --output_fun Linear 输入保存
接着我们就可以开始愉快的Debug了!
数据定义
首先看一下数据集的定义:
代码中作者自定义了数据集类(Data_utility):
这里纠正一个错误,图中n和m的意思说反了
def _normalized 和 def _split 分别是数据归一化和数据划分函数,此处不再赘述。
下面的_batchify函数是关键,该函数定义了数据集构成原理。
为什么说这里关键,因为这里决定了代码的数据输入与输出的格式,下面我们一步一步来看:
Debug可知,这里的X,Y的shape如下所示,其中X的shape为(数据总数量,时间序列窗口大小,一条数据包含多少个值),强调一下这里的137是因为原始数据每一个时间点都由137个光伏电站的值组成,所以本代码也是同时进行输出137个电站的结果。
可以注意到这里的Y怎么只有两维?真相只有一个,作者只是进行了单步(One-Step)预测,也就是说,作者只使用168条历史值去预测未来一个值,看到这里你可能会像我一样想到,(吐槽)卧槽这不是扯淡么?我预测一个值有什么意义?然而真相就是在TimeSeries系列的文章中,基本上都是只做单步预测。为什么?因为单步预测很准,误差也很低,更容易发论文,然而单步预测在实际工程应用中可以说没有任何意义与作用。
那么问题来了,如果我们想进行多步预测该怎么做?原来的Y(31357,137)等价于(31357,1,137),而中间多出来的这个1,就是我们的输出步长,在进行多步预测时,需要将Y也就是Target的shape改变为3维的形式,比如我想预测未来5个值,那么这里的shape应该为(31357,5,137)
接着,进行无数次的循环后,就构成了我们所需要的训练集。
网络核心流程
卷积
下面进入训练步骤:
我们的X(输入值)的shape如下:128为Batch_size
X小伙计经过了view转换为进行卷积所需要的形状C(128,1,168,137)->(batch size,通道数,输入的高度,输入的宽度):
这里来看一下self.conv1是个什么牛马?:
这里的kernel_size = (6,137) , hidC为卷积后的隐藏单元数量,通道数为1。
卷积后,C小伙计的形状就变成了(128,100,163,1),这里进行一下解释:128为Batch size,不改变,100为隐藏单元的数量(卷积核数量),所以卷积后变为100,卷积的输出大小计算公式如下:
经过卷积操作后的输出结构为(Batch Size,输出通道数,输出的高度,输出的宽度),
这里所使用的卷核心为(6,137),不是常规的N*N卷积核。
因为第四维=1,所以接着通过torch.squeeze()函数对C小伙计的第三维和第四维进行压缩,特征形状变为了(128,100,163)
此时,特征的卷积操作就结束了,接着进入RNN(循环神经网络)流程
RNN
这里说一下为什么要把第三维转变到第一维:前面说过经过上一个卷积步骤的操作后,特征形状为(128,100,163)即(batch _size,channels,features),这里的channels实际上为隐藏单元的数量,我们的LSTM以及GRU单元需要的输入特征形状为(sequence_length,batch,input_size),因此上一步的163也就等同于sequence_length即序列长度。转换后r的shape为(163,128,100)
这里的r(163,128,100)经过GRU单元得到了两个输出,首先看一下GRU的原理图:
由上图可知,GRU单元同样也是需要两个输入分别是h0和Input,Input即为上面提到的r,h0在没有提供时默认为0,因为在初始状态下GRU单元没有可以接收的信息。特征在经过GRU单元后得到了Output和hn两个值,其中Output的shape与Input相同,均为(seq_len, batch, input_size),hn的值实际上与Output中最后一个输出h4相同,hn的shape为(num_layers * num_directions, batch, hidden_size),这里第一维num_layers * num_directions即GRU单元数量乘以GRU方向(单向或双向),因此r(163,128,100)经过GRU单元便得到了_(163,128,100)和r(1,128,100),输出中我们只需要用到hn,所以Output便用占位符代替。
经过RNN层后最终的输出为(128,100)
循环跳跃层
接下来是Skip-rnn环节:
直接看论文中的图可能不好理解,我们从代码入手,当Skip>0时,跳过功能便会启用。
首先分析一下self.pt:
理解pt是什么东西是关键,pt最合理的解释应该是卷积后的特征所包含的周期数
其中P为时间序列窗口大小, Ck为卷积核大小,skip默认为24。这里有个问题,view()里不能出现float类型的数值,然而这里的pt经过求解后值为6.75,运行到这里会报错,作者的代码在这里有明显错误。因此我们需要将该行代码改为如下所示:
取整后pt值变为6,相应的其他与pt相关的数值也有所改变。
我在反复阅读论文和代码后,认为pt实际上代表数据包含多少个周期!所以周期skip应该根据数据格式来进行设置,比如一天有24条数据那就是24,一天有96条数据那就应该设置为96,作者源码所给出的Window Size为168,skip为24,因为该数据集一天有24条数据,所以168刚好是7天的数据,所以pt(多少个周期)也就为7了。然而,根据LSTNet的结构我们可以知道数据首先是需要经过CNN,再进入RNN,而经过CNN后的数据长度就会变为(Input size - kernel size +1)即(168-6+1=163),这下问题就显而易见了:163无法被24整除,经过卷积后提取出的特征无法被分为7个周期,所以只能通过下面的这一行代码对数据进行切片,舍弃掉前面无法凑出一个完整周期的数据:
上式等价于 c[:,:,- (6*24) :],经过该行代码,原来的c(128,100,163)被切片为s(128,100,144)
该行代码就是用来将144划分为6*24,此时s的Shape为(128,100,6,24)
接着往下看
这里通过permute()将s的shape转换为(6,128,24,100),下面来讲解一下为什么要这么干?
PS:打开viso,画个图会更好理解一点 :)
如图所示,将每长度为144的向量划分为6个长度为24的向量
因为GRU单元所接收的输入格式为(seq_len,batch,input_size),在这里我们希望将同一时刻的数据送入GRU单元,总共有6个周期(同一时刻有6个值),所以需要将seq_len设置为6,这里的input_size为GRU隐藏单元数量,所以为100不变,那么batch的大小就需要设置为128*24了。
经过上面行代码,s的shape变为(6,3072,100),这正好是GRU单元所期望的数据形式。
与之前的步骤一样,将经过GRU单元后的输出转换为同样的形式(128,240)。
再与之前RNN的输出r进行拼接,最终的输出为(128,340)。到这里,循环跳过层就结束了。
自回归层
自回归通俗来讲就是把历史值当做特征进行预测未来值,加入自回归则是在非线性的神经网络中引入了线性。LSTNet中自回归的操作非常简单,核心就是添加了一个全连接层。
首先看一下自回归层的输入x:
在本测试数据集中,数据最初的形状为(128,168,137),其中每个Batch中的数据为168*137,如下图所示:
其中每一行为在同一时刻下137个光伏厂站的负荷值,168即168个小时等价于7天。
自回归层的第一行代码则是取每一个Batch中最后24个时间点的数据(其中self.hw=24),如下图所示:
与前面的时间序列循环跳过层不同的是,循环跳过层注重的是预测日前7天内的时序关系,而这里的自回归操作则单纯是去寻找预测日前一天的历史值与预测日的值之间的线性关系。
最后两步则是将经过“自回归层”的特征放缩至与经过“卷积-循环跳跃层”后的特征同一维度,并进行拼接,构成了最终的特征。到这里,特征在LSTNet中就完整走完了一轮。经过无数次的迭代后即可训练出模型。
总结
最后做一下总结:
LSTNet是一个时间序列预测的算法框架,其中涉及到了卷积神经网络、循环神经网络、循环跳跃层、注意力机制(在作者开源的代码中未体现)、自回归。
CNN(卷积):提取数据整体特征
RNN(循环):提取数据时序特征
RNN-Skip(循环跳跃):提取数据同一时刻之间的特征 举个栗子:7天(168个小时)中,每天8点的数据(共7个)之间的关系
AR(自回归):引入 “历史 to 未来” 线性关系
流程图如下图所示:
结论:LSTNet开源代码为单步预测,实际应用价值较低。作者开源代码有些许不合理之处(当然,也或许是本人学艺不精,未看懂其中的奥妙,还请大家多多指点),并且有所保留。不过注意力机制这种即插即用的模块也比较容易添加,后续有空改一版带注意力机制的多步预测的LSTNet。
本文至此结束,感谢您的耐心阅读 : )