GluonTS 时间序列预测框架深度教程:从数据准备到模型训练

GluonTS 时间序列预测框架深度教程:从数据准备到模型训练

gluonts awslabs/gluonts: GluonTS (Gluon Time Series) 是一个由Amazon Web Services实验室维护的时间序列预测库,基于Apache MXNet的Gluon API构建,适用于各种商业应用中复杂时间序列数据的建模和预测任务。 gluonts 项目地址: https://gitcode.com/gh_mirrors/gl/gluonts

前言

时间序列预测是机器学习领域的重要应用场景,涵盖金融、零售、能源等多个行业。GluonTS 作为一个强大的概率时间序列预测工具包,提供了从数据处理到模型训练的全套解决方案。本文将深入讲解 GluonTS 的核心功能,帮助开发者快速掌握这一工具。

环境准备

在开始之前,我们需要导入必要的 Python 库:

import mxnet as mx
from mxnet import gluon
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

设置随机种子保证实验可复现性:

mx.random.seed(0)
np.random.seed(0)

数据准备

1. GluonTS 内置数据集

GluonTS 提供了多个公开可用的时间序列数据集,方便用户快速开始实验:

from gluonts.dataset.repository import get_dataset, dataset_names
from gluonts.dataset.util import to_pandas

# 查看所有可用数据集
print(f"可用数据集: {dataset_names}")

# 加载M4每小时数据集
dataset = get_dataset("m4_hourly")

每个数据集包含三个主要部分:

  • train: 训练集,包含多个时间序列
  • test: 测试集,包含比训练集更长的序列用于评估
  • metadata: 数据集元信息,如预测长度、频率等

2. 可视化分析

理解数据是建模的第一步,我们可以可视化查看时间序列:

train_entry = next(iter(dataset.train))
test_entry = next(iter(dataset.test))

test_series = to_pandas(test_entry)
train_series = to_pandas(train_entry)

fig, ax = plt.subplots(2, 1, figsize=(10, 7))
train_series.plot(ax=ax[0])
test_series.plot(ax=ax[1])
ax[1].axvline(train_series.index[-1], color='r')
plt.show()

3. 创建人工数据集

当我们需要特定模式的数据时,可以创建人工数据集:

from gluonts.dataset.artificial import ComplexSeasonalTimeSeries

artificial_dataset = ComplexSeasonalTimeSeries(
    num_series=10,
    prediction_length=21,
    freq_str="H",
    length_low=30,
    length_high=200
)

4. 自定义数据集

实际项目中,我们通常需要处理自己的数据。GluonTS 要求数据集至少包含:

  • start: 时间序列起始时间
  • target: 时间序列值

可选字段包括静态特征、动态特征等。下面是一个创建自定义数据集的示例:

from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName

# 创建包含100个时间序列的数据集
custom_ds_metadata = {
    'num_series': 100,
    'num_steps': 24*7,
    'prediction_length': 24,
    'freq': '1H'
}

# 转换为GluonTS格式
train_ds = ListDataset(
    [{'target': target, 'start': start} 
     for target, start in zip(...)],
    freq=custom_ds_metadata['freq']
)

数据转换

1. 转换管道

GluonTS 使用转换管道(Transformation)来预处理数据:

from gluonts.transform import *

def create_transformation(freq, context_length, prediction_length):
    return Chain([
        AddObservedValuesIndicator(),  # 添加观测值指示器
        AddAgeFeature(),              # 添加时间年龄特征
        InstanceSplitter(             # 实例分割器
            past_length=context_length,
            future_length=prediction_length
        )
    ])

2. 转换应用

将转换应用到数据集:

transformation = create_transformation(
    custom_ds_metadata['freq'], 
    2*custom_ds_metadata['prediction_length'],
    custom_ds_metadata['prediction_length']
)

train_tf = transformation(iter(train_ds), is_train=True)

转换后的数据包含分割后的时间窗口和各种特征:

train_tf_entry = next(iter(train_tf))
print(f"过去目标形状: {train_tf_entry['past_target'].shape}")
print(f"未来目标形状: {train_tf_entry['future_target'].shape}")

模型训练

1. 模型定义

GluonTS 提供了多种预测模型,以DeepAR为例:

from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.trainer import Trainer

estimator = DeepAREstimator(
    freq=custom_ds_metadata['freq'],
    prediction_length=custom_ds_metadata['prediction_length'],
    trainer=Trainer(epochs=10)
)

2. 训练过程

predictor = estimator.train(train_ds)

预测与评估

1. 生成预测

from gluonts.evaluation import Evaluator

forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_ds,
    predictor=predictor
)

forecasts = list(forecast_it)
tss = list(ts_it)

2. 评估指标

evaluator = Evaluator()
agg_metrics, item_metrics = evaluator(tss, forecasts)

print(f"平均MASE: {agg_metrics['MASE']}")
print(f"平均sMAPE: {agg_metrics['sMAPE']}")

高级特性

1. 使用外部特征

GluonTS 支持将外部特征纳入模型:

# 在数据集中添加动态特征
train_ds = ListDataset(
    [{
        'target': target,
        'start': start,
        'feat_dynamic_real': [dynamic_feature]
    } for ...],
    freq='1H'
)

2. 处理缺失值

通过转换管道处理缺失值:

transformation = Chain([
    AddObservedValuesIndicator(),
    # 其他转换...
])

总结

本文详细介绍了使用 GluonTS 进行时间序列预测的完整流程:

  1. 数据准备:内置数据集、人工数据和自定义数据
  2. 数据转换:特征工程和窗口分割
  3. 模型训练:DeepAR等模型的配置和训练
  4. 预测评估:生成预测和计算评估指标

GluonTS 的强大之处在于其模块化设计,开发者可以灵活组合各种组件,构建适合自己业务场景的预测解决方案。通过本教程,读者应该能够掌握 GluonTS 的核心概念,并能够将其应用到实际项目中。

gluonts awslabs/gluonts: GluonTS (Gluon Time Series) 是一个由Amazon Web Services实验室维护的时间序列预测库,基于Apache MXNet的Gluon API构建,适用于各种商业应用中复杂时间序列数据的建模和预测任务。 gluonts 项目地址: https://gitcode.com/gh_mirrors/gl/gluonts

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

陆可鹃Joey

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值