GluonTS项目中使用PyTorch Lightning构建自定义时间序列模型

GluonTS项目中使用PyTorch Lightning构建自定义时间序列模型

gluonts awslabs/gluonts: GluonTS (Gluon Time Series) 是一个由Amazon Web Services实验室维护的时间序列预测库,基于Apache MXNet的Gluon API构建,适用于各种商业应用中复杂时间序列数据的建模和预测任务。 gluonts 项目地址: https://gitcode.com/gh_mirrors/gl/gluonts

前言

时间序列预测是机器学习领域的重要应用场景之一。GluonTS作为一个功能强大的时间序列预测工具库,提供了丰富的模型和工具。本文将重点介绍如何在GluonTS框架中使用PyTorch Lightning来实现自定义的时间序列预测模型。

准备工作

在开始构建模型前,我们需要准备环境和数据。首先导入必要的Python库:

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import lightning.pytorch as pl
from matplotlib import pyplot as plt
import matplotlib.dates as mdates

数据集介绍

我们将使用电力消耗数据集作为示例。这个数据集包含了多个电力消耗时间序列,非常适合用于时间序列预测模型的训练和测试。

from gluonts.dataset.repository import get_dataset
dataset = get_dataset("electricity")

通过可视化可以直观了解数据特征:

date_formater = mdates.DateFormatter('%Y')
fig = plt.figure(figsize=(12,8))
for idx, entry in enumerate(islice(dataset.train, 9)):
    ax = plt.subplot(3, 3, idx+1)
    t = pd.date_range(start=entry["start"].to_timestamp(), 
                      periods=len(entry["target"]), 
                      freq=entry["start"].freq)
    plt.plot(t, entry["target"])
    plt.xticks(pd.date_range(start="2011-12-31", periods=3, freq="AS"))
    ax.xaxis.set_major_formatter(date_formater)

模型架构设计

我们将构建一个基于前馈神经网络的概率预测模型。该模型的核心特点是:

  1. 使用神经网络输出参数化分布的参数
  2. 默认使用Student's t分布作为输出分布
  3. 支持自定义隐藏层维度
class FeedForwardNetwork(nn.Module):
    def __init__(
        self,
        prediction_length: int,
        context_length: int,
        hidden_dimensions: List[int],
        distr_output = StudentTOutput(),
        batch_norm: bool=False,
        scaling: Callable=mean_abs_scaling,
    ) -> None:
        super().__init__()
        # 初始化代码...
        
    def forward(self, past_target):
        # 前向传播逻辑...
    
    def get_predictor(self, input_transform, batch_size=32):
        # 创建预测器...

使用PyTorch Lightning进行训练

PyTorch Lightning简化了训练过程的管理。我们只需继承基础模型类并实现几个关键方法:

class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule):
    def training_step(self, batch, batch_idx):
        # 定义训练步骤...
    
    def configure_optimizers(self):
        # 配置优化器...

数据加载与预处理

时间序列数据需要特殊的预处理方式:

  1. 处理缺失值
  2. 创建适当的训练实例
  3. 批量处理数据
mask_unobserved = AddObservedValuesIndicator(
    target_field=FieldName.TARGET,
    output_field=FieldName.OBSERVED_VALUES,
)

training_splitter = InstanceSplitter(
    # 配置实例分割参数...
)

data_loader = TrainDataLoader(
    # 配置数据加载器...
)

模型训练与评估

训练过程非常简单:

trainer = pl.Trainer(max_epochs=10)
trainer.fit(net, data_loader)

训练完成后,我们可以创建预测器并进行评估:

predictor_pytorch = net.get_predictor(mask_unobserved + prediction_splitter)

forecast_it, ts_it = make_evaluation_predictions(
    dataset=dataset.test, predictor=predictor_pytorch
)

# 可视化预测结果
plt.figure(figsize=(20, 15))
# 绘图代码...

性能评估

最后,我们计算模型的评估指标:

evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
metrics_pytorch, _ = evaluator(tss_pytorch, forecasts_pytorch)
pd.DataFrame.from_records(metrics_pytorch, index=["FeedForward"]).transpose()

总结

本文详细介绍了在GluonTS框架中使用PyTorch Lightning构建自定义时间序列模型的完整流程。通过这种方法,我们可以充分利用PyTorch的灵活性和PyTorch Lightning的训练管理能力,同时还能与GluonTS的其他组件无缝集成。这种组合为时间序列预测任务提供了强大的工具集。

gluonts awslabs/gluonts: GluonTS (Gluon Time Series) 是一个由Amazon Web Services实验室维护的时间序列预测库,基于Apache MXNet的Gluon API构建,适用于各种商业应用中复杂时间序列数据的建模和预测任务。 gluonts 项目地址: https://gitcode.com/gh_mirrors/gl/gluonts

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

龚格成

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值