什么是BasicTS?
BasicTS (Basic Time Series) 一个标准和公平的时间序列预测基准和工具箱。
它收纳了许多深度学习时间序列预测模型,包括单元时间序列预测以及多元时间序列预测中的一些经典或是最新的模型。你可以轻易使用它完成时间序列预测任务。
这是BasicTS的网站:zezhishao/BasicTS: An Standard and Fair Time Series Forecasting Benchmark and Toolkit. (github.com)
什么是时间序列预测?
时间序列是按照一定的时间间隔排列的一组数据,其时间间隔可以是任意的时间单位,如小时、日、周月等。
时间序列预测就是去学习去学习过去时间序列,以预测该时间序列的未来状态。
单元时间序列预测:单元时间序列预测我们又称之为单变量时间序列预测;一般而言,如ARIMA、LSTM都可以完成单元的时间序列预测,这意味着我们默认各个变元之间没有必然的关联,我们只是通过学习普遍的时间依赖,完成时间序列的预测。
多元时间序列预测:与上述的单元时间序列预测相反,我们认为各个变元之间是存在某种意义上的关联的,如交通预测中欧式空间中相邻、功能区相似。金融预测中属于相同的行业、持股人相似等。因此,多元时间序列预测不仅仅要学习时间依赖,还要学习变元之间的空间关系,以及空间与时间的交互关系等等。
时间序列预测一直是数据挖掘领域中的一个很大的话题,在交通、金融、环境、生物科学等众多领域都存在普遍的应用。
BasicTS使用说明-前言
BasicTS的目的旨在帮助您快速应用一些人工智能领域中的一些时间序列预测模型。
本文旨在帮助您快速上手BasicTS,且主要针对本地用户,比如您像我一样使用的是windows11与Pycharm。根据github上的使用说明,会踩一些坑,主要是文件路径上的一些问题。BasicTS的GitHub首页其实是有介绍和使用说明的,本文做了一些本地用户在实际使用中的补充。
使用说明-BasicTS
下载:再贴一下地址:zezhishao/BasicTS: An Standard and Fair Time Series Forecasting Benchmark and Toolkit. (github.com)
本地用户直接下载就好了,无论是GIt的clone或pull、Github上的download zip还是通过GIthub Desktop。文件很小,2M左右,作为benchmark已经是非常小了。
环境要求:requirements.txt里面都有,如果您的电脑里面啥都没有,直接 pip install -r requirements.txt 就行。但我这种本地用户,torch这样重要的环境都是最新版本,只能在数据预处理好后的上手阶段慢慢配置环境,缺什么就下什么。torch会有一点点不匹配的问题,我们后面讲述,其他的库直接下最新的就行,不会遇到什么问题。
数据集下载与预处理:数据集在百度云里面下载,密码:6v0a。压缩包大小170M,解压到../datasets/raw_data/里面,解压完400M出头。百度网盘 请输入提取码百度网盘为您提供文件的网络备份、同步和分享服务。空间大、速度快、安全稳固,支持教育网加速,支持手机端。注册使用百度网盘即可享受免费存储空间https://pan.baidu.com/share/init?surl=0gOPtlC9M4BEjx89VD1Vbw 然后,数据集预处理就开始踩坑了,不太清楚服务器端是否会出现相同的问题,本地用户在运行scripts/data_preparation里面对应的generate_training_data.py文件或是bash all.sh时,会遇到找不到文件路径的问题,以及后续的读取预处理好的数据时,也会有路径问题。
通过检查generate_training_data.py,打印出输入输出的绝对路径,会发现dataset/raw_data/以及dataset/的索引都是在当前目录,当然找不到数据。
修改读取路径dataset/raw_data/为绝对路径,如我的: D:/myfile/BasicTS/datasets/raw_data/, 其中/BasicTS/为存在BasicTS的根目录,scripts、examples等等都在该目录下。
修改保存路径dataset/为examples下的绝对路径,如我的:D:/myfile/BasicTS/examples/datasets/
all.sh里面应该是相同的问题。将需要的数据集预处理完就可以上手测试了。
上手测试:在examples里面的run.py修改parse_args里面的程序,运行谁就把注释删除,把其他的注释掉。由于参数是通过配置文件(cfg)来定义的,您可以按照模型_数据集寻找examples里面的对应py文件,并根据需要修改参数。此外,run.py中的gpu选取改为0,除非您有多张显卡。。。
模型的主要代码在basicts/archs/arch_zoo/里面,可以根据需要查看。
踩坑:GraphWaveNet作为非常经典的多元时间序列模型,毕竟比较早了,我的torch版本为1.13,会遇到conv1d的报错,在arch_zoo/对应的文件下,直接修改为2d(?也许与原本模型不符)
创建您自己的模型:examples里面的MLP给出了您所需要的一些文件,包括cfg的参数配置文件,arch的算法文件以及runner的运行文件。
创建您自己的模型时,将cfg的py文件放到examples的文件夹内,arch的py文件放在basicts/archs/arch_zoo/,cfg的py文件索引的runners文件放在basicts/runners中。
除此以外,自己定义模型时应该会遇到模型名称的报错,因为某些__init__.py文件中缺少您定义的模型名称,遇到报错索引到相应的py地址加上去就行。
其他的loss、metrics都在basicts/里面,需要的自行修改添加。
运行结果保存在examples/checkpoints的对应文件夹下,由于是断点续训,所以重新运行需要把此前的pt参数文件删除。
已有模型的结果在result里面的图片中,您可以对比您的模型与此前的经典模型。
这里贴一段自用的可以快速查看您自己模型的运行效果的代码,根据val中每个epoch的mae最好结果,打印该epoch下horizon为3/6/12的mae、rmse、mape,以及平均的mae、rmse、mape。也就是result中图片所需的结果。并且应用matplotlib做出所有epoch的主要结果。(注意修改path路径)
祝您BasicTS使用愉快!
import matplotlib.pyplot as plt
import numpy as np
path = 'D:/myfile/BasicTS/examples/checkpoints/MTGNN/' + 'training_log_20230423223558.log'
l_mae_3 = []
l_rmse_3 = []
l_mape_3 = []
l_mae_6 = []
l_rmse_6 = []
l_mape_6 = []
l_mae_12 = []
l_rmse_12 = []
l_mape_12 = []
l_epoch = []
l_val_mae = []
l_best_mae = []
l_best_rmse = []
l_best_mape = []
with open(path,'r') as f:
for i in f.readlines():
s = i.split(' ')
print(s)
if 'Epoch' in s:
l_epoch.append(int(s[8]))
if 'horizon' and '3,' in s:
l_mae_3.append(float(s[18][:-1]))
l_rmse_3.append(float(s[21][:-1]))
l_mape_3.append(float(s[24][:-1]))
if 'horizon' and '6,' in s:
l_mae_6.append(float(s[18][:-1]))
l_rmse_6.append(float(s[21][:-1]))
l_mape_6.append(float(s[24][:-1]))
if 'horizon' and '12,' in s:
l_mae_12.append(float(s[18][:-1]))
l_rmse_12.append(float(s[21][:-1]))
l_mape_12.append(float(s[24][:-1]))
if 'val_MAE:' in s:
l_val_mae.append(float(s[13][:-1]))
if '<test>:' in s:
l_best_mae.append(float(s[13][:-1]))
l_best_rmse.append(float(s[15][:-1]))
l_best_mape.append(float(s[17][:-2]))
print(l_epoch)
print(l_best_mae)
n = l_val_mae.index(min(l_val_mae))
print('best_mae3: ',l_mae_3[n],l_rmse_3[n],l_mape_3[n])
print('best_mae6: ',l_mae_6[n],l_rmse_6[n],l_mape_6[n])
print('best_mae12: ',l_mae_12[n],l_rmse_12[n],l_mape_12[n])
print('best_test: ',l_best_mae[n],l_best_rmse[n],l_best_mape[n])
plt.figure(figsize=(8,5),dpi=100)
x = [i+1 for i in range(len(l_best_mae))]
#x = l_epoch
plt.plot(x,l_mae_3,label='MAE3')
plt.plot(x,l_mae_6,label='MAE6')
plt.plot(x,l_mae_12,label='MAE12')
plt.legend()
plt.show()
plt.figure(figsize=(8,5),dpi=100)
#x = [i for i in range(len(l_epoch))]
plt.plot(x,l_rmse_3,label='rmse3')
plt.plot(x,l_rmse_6,label='rmse6')
plt.plot(x,l_rmse_12,label='rmse12')
plt.legend()
plt.show()
plt.figure(figsize=(8,5),dpi=100)
#x = [i for i in range(len(l_epoch))]
plt.plot(x,l_mape_3,label='mape3')
plt.plot(x,l_mape_6,label='mape6')
plt.plot(x,l_mape_12,label='mape12')
plt.legend()
plt.show()
plt.figure(figsize=(8,5),dpi=100)
#x = [i for i in range(len(l_epoch))]
plt.plot(x,l_best_mae,label='mse')
plt.plot(x,l_best_rmse,label='rmse')
plt.plot(x,l_best_mape,label='mape')
plt.legend()
plt.show()
plt.figure(figsize=(8,5),dpi=100)
#x = [i for i in range(len(l_epoch))]
plt.plot(x,l_val_mae,label='mse')
plt.legend()
plt.show()