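This article walks through a complete time series forecasting workflow with a temporal convolutional network (TCN), with detailed comments throughout. The steps are: importing the data, cleaning it, reshaping it into a supervised-learning structure, building the TCN model, training it (with dynamic learning-rate adjustment and early stopping), predicting, plotting the results, and evaluating the error.
The tcn library used here is among the resources I have uploaded: tcn.py
The dataset used here is among the resources I have uploaded: mock_kaggle.csv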
import math
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error  # mean squared error
# Import every Keras object from tensorflow.keras; mixing it with the
# standalone keras package can cause errors
from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras import backend as K
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping
from tcn import TCN, tcn_full_summary
mpl.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese labels
mpl.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
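Load the data
# Read the CSV (GBK-encoded) and parse the 'datetime' column as timestamps
data = pd.read_csv('mock_kaggle.csv', encoding='gbk', parse_dates=['datetime'])
Date = pd.to_datetime(data.datetime)
# Keep a string version of the date (YYYY-MM-DD) to use as the series index
data['date'] = Date.map(lambda x: x.strftime('%Y-%m-%d'))
datanew = data.set_index(Date)
# The target variable is the '股票' column, indexed by date
series = pd.Series(datanew['股票'].values, index=datanew['date'])
series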
date
2014-01-01 4972
2014-01-02 4902
2014-01-03 4843
2014-01-04 4750
2014-01-05 4654
...
2016-07-27 3179
2016-07-28 3071
2016-07-29 4095
2016-07-30 3825
2016-07-31 3642
Length: 937, dtype: int64
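One thing worth checking at this point: the printed range runs from 2014-01-01 to 2016-07-31, which is 943 calendar days, while the series has 937 rows, so a few dates appear to be absent. The following is a minimal sketch of how such gaps could be listed. It uses a small in-memory table rather than mock_kaggle.csv; the `value` column name and the deliberately removed date are assumptions made for illustration only.

```python
import io
import pandas as pd

# Hypothetical stand-in for the real file: 'value' plays the role of the
# target column, and 2014-01-03 is left out on purpose to simulate a gap.
csv_text = (
    "datetime,value\n"
    "2014-01-01,4972\n"
    "2014-01-02,4902\n"
    "2014-01-04,4750\n"
    "2014-01-05,4654\n"
)
demo = pd.read_csv(io.StringIO(csv_text), parse_dates=['datetime'])

# Every calendar day between the first and last observation
full_range = pd.date_range(demo['datetime'].min(), demo['datetime'].max(), freq='D')

# Days that are in the calendar range but not in the data
missing_dates = full_range.difference(pd.DatetimeIndex(demo['datetime']))

# Number of missing target values among the rows that are present
n_nan = int(demo['value'].isna().sum())

print(len(full_range), len(demo))
print(list(missing_dates.strftime('%Y-%m-%d')))
```

Whether the gaps matter depends on the model: `shift` moves values by rows rather than by calendar days, so a lag labelled t-1 spans more than one day wherever a date is missing.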
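Build lagged features
dataframe1 = pd.DataFrame()
num_hour = 16  # number of lagged steps used as input features
# Create columns t-16 ... t-1, each holding the series shifted by i rows
for i in range(num_hour, 0, -1):
    dataframe1['t-' + str(i)] = series.shift(i)
dataframe1['t'] = series.values  # current value, used as the target
# The first num_hour rows contain NaN lags, so drop them and renumber the rows
dataframe3 = dataframe1.dropna()
dataframe3.index = range(len(dataframe3))
dataframe3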
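The idea of this step is that each row becomes one training sample: the previous values are the inputs and the current value is the target. Before applying it to the full series, here is a minimal sketch on a toy series. The names `toy`, `n_lags`, `X`, `y` and `X_3d` are hypothetical, and the final reshape to (samples, timesteps, 1) is an assumption about what a Keras sequence layer such as TCN generally expects, not necessarily how the rest of this article does it.

```python
import numpy as np
import pandas as pd

toy = pd.Series([10, 20, 30, 40, 50, 60])
n_lags = 3  # hypothetical window length; the article uses 16

# Columns t-3, t-2, t-1: the series shifted down by i rows
lagged = pd.DataFrame({f't-{i}': toy.shift(i) for i in range(n_lags, 0, -1)})
lagged['t'] = toy.values  # current value, the prediction target

# The first n_lags rows have NaN lags and cannot be used as samples
lagged = lagged.dropna().reset_index(drop=True)

# Split into inputs and target
X = lagged.drop(columns='t').to_numpy()
y = lagged['t'].to_numpy()

# Assumed input layout for a sequence model: (samples, timesteps, features)
X_3d = X.reshape(X.shape[0], n_lags, 1)

print(lagged)
print(X_3d.shape, y.shape)
```

With 6 observations and 3 lags, 3 samples remain. By the same arithmetic, 937 observations with 16 lags leave 937 - 16 = 921 samples.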
index | t-16 | t-15 | t-14 | t-13 | t-12 | t-11 | t-10 | t-9 | t-8 | t-7 | t-6 | t-5 | t-4 | t-3 | t-2 | t-1 | t |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 4972.0 | 4902.0 | 4843.0 | 4750.0 | 4654.0 | 4509.0 | 4329.0 | 4104.0 | 4459.0 | 5043.0 | 5239.0 | 5118.0 | 4984.0 | 4904.0 | 4822.0 | 4728.0 | 4464 |
1 | 4902.0 | 4843.0 | 4750.0 | 4654.0 | 4509.0 | 4329.0 | 4104.0 | 4459.0 | 5043.0 | 5239.0 | 5118.0 | 4984.0 | 4904.0 | 4822.0 | 4728.0 | 4464.0 | 4265 |
2 | 4843.0 | 4750.0 | 4654.0 | 4509.0 | 4329.0 | 4104.0 | 4459.0 | 5043.0 | 5239.0 | 5118.0 | 4984.0 | 4904.0 | 4822.0 | 4728.0 | 4464.0 | 4265.0 | 4161 |
3 | 4750.0 | 4654.0 | 4509.0 | 4329.0 | 4104.0 | 4459.0 | 5043.0 | 5239.0 | 5118.0 | 4984.0 | 4904.0 | 4822.0 | 4728.0 | 4464.0 | 4265.0 | 4161.0 | 4091 |
4 | 4654.0 | 4509.0 | 4329.0 | 4104.0 | 4459.0 | 5043.0 | 5239.0 | 5118.0 | 4984.0 | 4904.0 | 4822.0 | 4728.0 | 4464.0 | 4265.0 | 4161.0 | 4091.0 | 3964 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
916 | 1939.0 | 1967.0 | 1670.0 | 1532.0 | 1343.0 | 1022.0 | 813.0 |