DeepAR笔记

DeepAR-概率预测模型

介绍

DeepAR是Amazon在2017年提出的基于深度学习的时间序列预测方法,目前在Amazon机器学习平台Amazon SageMaker和其开源的时序预测工具库GluonTS里有集成

RNN

一般常见的RNN模型,如下图,当输入 x t − 1 x_{t-1} xt1时,通过隐藏层和状态的计算后得到 o t − 1 o_{t-1} ot1,然后再输入 x t x_t xt计算后得到 o t o_t ot,其中, o t − 1 o_{t-1} ot1 o t o_{t} ot为具体“预测”的值

相当于输入一个序列X,可以得到输出序列O
当然这种情况下Seq2Seq也是得到一个输出序列

在这里插入图片描述

DeepAR

在DeepAR中,有数据输入后,通过计算得到的输出并不是具体“预测值”,而是先得到一个概率模型,例如高斯概率模型,然后再从这个概率模型中去采样出一个“预测值”

在这里插入图片描述

按照论文中的说法, z i , t z_{i,t} zi,t表示第i个序列在时间t的值, x i , t x_{i,t} xi,t表示额外的特征(协变量)

论文中使用多条时间序列来训练一个模型,使用embedding加以区分不同的时间序列,所以用i来表示序列的编号

所以,DeepAR模型使用的不是常见RNN那种直接预测一个点的模式,而是去预测一个概率,再从这个概率去获得预测值

这样有几个好处

  • 点预测时的结果就是一个点,一个具体的数值,而概率预测相当于是去预测这个点的概率分布,并且可以使用概率分布的特征来描述该点可能出现的范围
  • 当输出一个概率后,采样获得一个数值,然后将这个数值作为下一个输入,这样就增加了一定的鲁棒性。因为这个数值时通过概率采样出来的,每次都会不一样。

Demo

一个简单的Demo

实验设置

  • 实验使用Twitter_volume_AMZN.csv数据,仅包含一条时序数据,没有协变量
  • MXNet,GluonTS
  • 具体可以去了解GluonTS
import mxnet as mx
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer
from gluonts.dataset.common import ListDataset
from itertools import islice
from gluonts.evaluation.backtest import make_evaluation_predictions


df = pd.read_csv('data/Twitter_volume_AMZN.csv', index_col=0,parse_dates=True)

training_data = ListDataset(
    [{"start": df.index[0], "target": df.value[:"2015-04-10 00:00:00"]}],
    freq = "5min"
)
test_data = ListDataset(
    [{"start": df.index[0], "target": df.value[:"2015-04-15 00:00:00"]},
     {"start": df.index[0], "target": df.value[:"2015-04-16 00:00:00"]},
     {"start": df.index[0], "target": df.value[:"2015-04-17 00:00:00"]}],
    freq = "5min"
)

estimator = DeepAREstimator(freq="5min",
                            prediction_length=24,
                            trainer=Trainer(
                                ctx=mx.context.gpu(),
                                epochs=20,
                                learning_rate=1e-2))

predictor = estimator.train(training_data=training_data)

def plot_forecasts(tss, forecasts, past_length, num_plots):
    for target, forecast in islice(zip(tss, forecasts), num_plots):
        ax = target[-past_length:].plot(figsize=(12, 5), linewidth=2)
        forecast.plot(color='g')
        plt.grid(which='both')
        plt.legend(["observations", "median prediction", "90% confidence interval", "50% confidence interval"],
                    prop={'size':12})
        plt.show()

forecast_it, ts_it = make_evaluation_predictions(test_data, predictor=predictor, num_samples=100)


forecasts = list(forecast_it)
tss = list(ts_it)

plot_forecasts(tss, forecasts, past_length=100, num_plots=3)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考

  • https://arxiv.org/abs/1704.04110
  • https://docs.aws.amazon.com/zh_cn/sagemaker/latest/dg/deepar.html
  • https://ts.gluon.ai/

如果不正确的地方请指出,谢谢

  • 4
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值