DeepAR-概率预测模型
介绍
DeepAR是Amazon在2017年提出的基于深度学习的时间序列预测方法,目前在Amazon机器学习平台Amazon SageMaker和其开源的时序预测工具库GluonTS里有集成
RNN
一般常见的RNN模型,如下图,当输入 x t − 1 x_{t-1} xt−1时,通过隐藏层和状态的计算后得到 o t − 1 o_{t-1} ot−1,然后再输入 x t x_t xt计算后得到 o t o_t ot,其中, o t − 1 o_{t-1} ot−1和 o t o_{t} ot为具体“预测”的值
相当于输入一个序列X
,可以得到输出序列O
当然这种情况下Seq2Seq也是得到一个输出序列
DeepAR
在DeepAR中,有数据输入后,通过计算得到的输出并不是具体“预测值”,而是先得到一个概率模型,例如高斯概率模型,然后再从这个概率模型中去采样出一个“预测值”
按照论文中的说法,
z
i
,
t
z_{i,t}
zi,t表示第i
个序列在时间t
的值,
x
i
,
t
x_{i,t}
xi,t表示额外的特征(协变量)
论文中使用多条时间序列来训练一个模型,使用embedding加以区分不同的时间序列,所以用i
来表示序列的编号
所以,DeepAR模型使用的不是常见RNN那种直接预测一个点的模式,而是去预测一个概率,再从这个概率去获得预测值
这样有几个好处
- 点预测时的结果就是一个点,一个具体的数值,而概率预测相当于是去预测这个点的概率分布,并且可以使用概率分布的特征来描述该点可能出现的范围
- 当输出一个概率后,采样获得一个数值,然后将这个数值作为下一个输入,这样就增加了一定的鲁棒性。因为这个数值时通过概率采样出来的,每次都会不一样。
Demo
一个简单的Demo
实验设置
- 实验使用Twitter_volume_AMZN.csv数据,仅包含一条时序数据,没有协变量
- MXNet,GluonTS
- 具体可以去了解GluonTS
import mxnet as mx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer
from gluonts.dataset.common import ListDataset
from itertools import islice
from gluonts.evaluation.backtest import make_evaluation_predictions
df = pd.read_csv('data/Twitter_volume_AMZN.csv', index_col=0,parse_dates=True)
training_data = ListDataset(
[{"start": df.index[0], "target": df.value[:"2015-04-10 00:00:00"]}],
freq = "5min"
)
test_data = ListDataset(
[{"start": df.index[0], "target": df.value[:"2015-04-15 00:00:00"]},
{"start": df.index[0], "target": df.value[:"2015-04-16 00:00:00"]},
{"start": df.index[0], "target": df.value[:"2015-04-17 00:00:00"]}],
freq = "5min"
)
estimator = DeepAREstimator(freq="5min",
prediction_length=24,
trainer=Trainer(
ctx=mx.context.gpu(),
epochs=20,
learning_rate=1e-2))
predictor = estimator.train(training_data=training_data)
def plot_forecasts(tss, forecasts, past_length, num_plots):
for target, forecast in islice(zip(tss, forecasts), num_plots):
ax = target[-past_length:].plot(figsize=(12, 5), linewidth=2)
forecast.plot(color='g')
plt.grid(which='both')
plt.legend(["observations", "median prediction", "90% confidence interval", "50% confidence interval"],
prop={'size':12})
plt.show()
forecast_it, ts_it = make_evaluation_predictions(test_data, predictor=predictor, num_samples=100)
forecasts = list(forecast_it)
tss = list(ts_it)
plot_forecasts(tss, forecasts, past_length=100, num_plots=3)
参考
- https://arxiv.org/abs/1704.04110
- https://docs.aws.amazon.com/zh_cn/sagemaker/latest/dg/deepar.html
- https://ts.gluon.ai/
如果不正确的地方请指出,谢谢