时间序列-源码解读-gluonts

关注我的公众号YueTan进行交流探讨
欢迎关注数据比赛方案仓库 https://github.com/hongyingyue/Competition-solutions

  • https://github.com/awslabs/gluonts
  • https://github.com/PaddlePaddle/PaddleTS
  • https://github.com/LongxingTan/Time-series-prediction
  • https://github.com/zalandoresearch/pytorch-ts

这是时间序列的源码解读系列,将会对几个时间序列预测的包进行源码解读,理解其设计和特点。便于优化我自己写的tfts包

整体

规范

  • 每个大版本都有一个git分支

细节

  • 我一直想有的get passenger data, 可以从中借鉴
from gluonts.dataset.repository.datasets import get_dataset

dataset = get_dataset("airpassengers")

为了自动快速训练模型,gluonts提供了几个快速构建数据的方法

  • FileDataset: 从文件中加载
  • ListDataset:样本是在list中一个一个的
    • 依赖的是ProcessDataEntry
      • ProcessTimeSeriesField

def get_sine(train_sequence_length=24, predict_sequence_length=8, test_size=0.2, n_examples=100):
    x = []
    y = []
    for _ in range(n_examples):
        rand = random.random() * 2 * np.pi
        sig1 = np.sin(np.linspace(rand, 3.0 * np.pi + rand, train_sequence_length + predict_sequence_length))
        sig2 = np.cos(np.linspace(rand, 3.0 * np.pi + rand, train_sequence_length + predict_sequence_length))

        x1 = sig1[:train_sequence_length]
        y1 = sig1[train_sequence_length:]
        x2 = sig2[:train_sequence_length]
        y2 = sig2[train_sequence_length:]

        x_ = np.array([x1, x2])
        y_ = np.array([y1, y2])

        x.append(x_.T)
        y.append(y_.T)

    x = np.array(x)[:, :, 0:1]
    y = np.array(y)[:, :, 0:1]
    logging.info("Load sine data", x.shape, y.shape)

    if test_size > 0:
        slice = int(n_examples * (1 - test_size))
        x_train = x[:slice]
        y_train = y[:slice]
        x_valid = x[slice:]
        y_valid = y[slice:]
        return (x_train, y_train), (x_valid, y_valid)
    return x, y

使用

from typing import Dict, List
import json
from gluonts.dataset.common import ListDataset
import subprocess
from gluonts.dataset.field_names import FieldName
from typing import Callable, Iterable, Iterator, List
import torch
from pts.model.deepar import DeepAREstimator
from pts import Trainer
import os
from pathlib import Path
from gluonts.torch.model.predictor import PyTorchPredictor
# import torch.distributed as dist
import numpy as np
import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class AlgSolution():
    def __init__(self):
        pass

    def train_model(self, input_data_path: str, output_model_path: str, params: Dict, **kwargs) -> bool:
        # """使用数据集训练模型

        # Args:
        #     input_data_path (str): 本地输入数据集路径
        #     output_model_path (str): 本地输出模型路径
        #     params (Dict): 训练输入参数。默认为conf/default.json

        # Returns:
        #     bool: True 成功; False 失败
        # """
        # load pretrained model if any
        # self.model = load_from_pretrained()
        # reading and processing data
        train_data_path = os.path.join(input_data_path, 'train.jsonl')
        print(device)
        train_samples = []
        with open(train_data_path) as f:
            for line in f:
                sample = json.loads(line)
                train_samples.append(sample)

        training_data = ListDataset(
            [{FieldName.TARGET: x['y'],
              FieldName.START: x['start'],
              FieldName.FEAT_STATIC_CAT: [x['item_id'][0]],
              # FieldName.FEAT_DYNAMIC_CAT:np.array([]).T
              } for x in train_samples],
            freq='10min',
        )
        # parameter setting and estimator setting
        config = {
            "epochs": 50,
            "num_batches_per_epoch": 150,
            "batch_size": 2048,
            "context_length": 48,
            "prediction_length":48,
            "lags_seq":[6, 12, 18, 24, 32, 38, 42, 48, 96],
            "num_cells":128,
            "cardinality":[2000],
            "embedding_dimension":[5],
            "learning_rate":1e-3,
        }
        print(config)
        trainer = Trainer(
            epochs=config["epochs"],
            batch_size=config["num_batches_per_epoch"],
            device=device,
            learning_rate=config["learning_rate"],
            num_batches_per_epoch=config["num_batches_per_epoch"])
        estimator = DeepAREstimator(freq="10min",
                                    context_length=config["context_length"],
                                    prediction_length=config["prediction_length"],
                                    input_size=22,
                                    lags_seq=config["lags_seq"],
                                    num_cells=config["num_cells"],
                                    trainer=trainer,
                                    cardinality=config["cardinality"],
                                    embedding_dimension=config["embedding_dimension"],
                                    use_feat_static_cat=True,
                                    )
        # train model
        print('model training start time:{}'.format(datetime.datetime.now()))
        predictor = estimator.train(training_data=training_data,num_workers=8,shuffle_buffer_length=1024)
        print('model trained start time:{}'.format(datetime.datetime.now()))

        # save model
        print('model saving start time:{}'.format(datetime.datetime.now()))
        predictor.serialize(Path(output_model_path))
        print('model saved start time:{}'.format(datetime.datetime.now()))

        cmd = 'cd {} && touch model && tar -czf model.tar.gz model'.format(output_model_path)
        ret, _ = subprocess.getstatusoutput(cmd)
        if ret != 0:
            return False
        return True

    def load_model(self, model_path: str, params: Dict, **kwargs) -> bool:
        """从本地加载模型

        Args:
            model_path (str): 本地模型路径
            params (Dict): 模型输入参数。默认为conf/default.json

        Returns:
            bool: True 成功; False 失败
        """
        # load model
        self.model = PyTorchPredictor.deserialize(Path(model_path),device=device)
        return True

    def predicts(self, sample_list: List[Dict], **kwargs) -> List[Dict]:
        """批量预测

        Args:
            sample_list (List[Dict]): 输入请求内容列表
            kwargs:
                __dataset_root_path (str): 本地输入路径
                __output_root_path (str):  本地输出路径

        Returns:
            List[Dict]: 输出预测结果列表
        """
        # 请将输出图片请放到output_path下
        # input_path = kwargs.get('__dataset_root_path')
        # output_path = kwargs.get('__output_root_path')
        # sample_list [{'':''},{'':''}]
        # 根据输入内容,填写计算的答案
        # inferencing data
        print('inferencing data:{}'.format(datetime.datetime.now()))
        test_data = ListDataset(
                [{FieldName.TARGET: x['y'],
                  FieldName.START: x['start'],
                  FieldName.FEAT_STATIC_CAT: [x['item_id'][0]],
                  # FieldName.FEAT_DYNAMIC_CAT:np.array([]).T
                  } for x in sample_list],
                freq='10min',
            )
        predicts = list(self.model.predict(test_data))
        ret = [{
            'prediction': [float(i) for i in list(pred.samples.mean(axis=0).reshape(-1, ))],
        } for pred in predicts]
        print('inferenced data:{}'.format(datetime.datetime.now()))

        return ret

这个例子是AETC比赛中流量预测的baseline

  • 加载json数据
  • 其中使用三列给数据类型,FieldName.TARGET: x[‘y’], FieldName.START: x[‘start’], FieldName.FEAT_STATIC_CAT: [x[‘item_id’][0]]
  • 模型 前的处理可以看pts-deepar
  • 采用deepar模型

处理思路

  • 样本
    • deepar内置了从一个更长尺度上进行截取变成多个样本
  • 使用特征包括
    • lag
    • item_id embed
    • start_time不知道有没有利用
  • 模型

  • 训练
      1. transformation = create_transformation
      1. net = create_training_network
      1. input_name = get_module_forward_input_name
      1. 如果,create_instance_splitter
      1. TransformedIterableDataset
      1. DataLoader
      1. 如果有valid_data,建立valid
      1. 建立trainer
      1. 返回TrainOutput

这很多思路也可以用于autots的思路中。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

YueTann

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值