时间序列预测任务PyTorch数据集类——TimeSeriesDataSet 类详解

时间序列预测任务PyTorch数据集类——TimeSeriesDataSet 类详解

当进行时间序列预测或时间序列分析时,通常需要对数据进行预处理和转换以提高模型的效果和准确性。TimeSeriesDataSet 类是为这些目的而创建的 PyTorch 数据集类,提供了一些自动化的功能,使得预处理和转换变得更加方便和高效。该类可以用于多种时间序列预测任务,例如预测股票价格、交通流量、能源消耗等。

PyTorch Dataset for fitting timeseries models.

The dataset automates common tasks such as

* scaling and encoding of variables
* normalizing the target variable
* efficiently converting timeseries in pandas dataframes to torch tensors
* holding information about static and time-varying variables known and unknown in the future
* holiding information about related categories (such as holidays)
* downsampling for data augmentation
* generating inference, validation and test datasets
* etc.

在该类中,一些自动化的功能包括:

  • 变量缩放和编码:对于不同的变量,可以通过指定缩放方法和编码方法来将其归一化,并减少变量之间的差异性,从而提高模型的效果。
  • 目标变量归一化:对于时间序列中的目标变量,可以对其进行归一化,以便更好地适应模型。
  • 数据转换:该类提供了一些方法,可以将时间序列数据从 pandas 数据框中转换为 PyTorch 张量,以便更好地适应 PyTorch 模型。
  • 变量信息的保存:该类可以保存关于未来已知和未知的静态和时变变量的信息,以便更好地处理多个时间步长的时间序列数据。
  • 相关类别信息的保存:该类可以保存与时间序列数据相关的类别信息,例如假日信息等,以便更好地处理具有类别信息的时间序列数据。
  • 数据增强:该类提供了下采样的功能,可以对时间序列数据进行降采样,以便更好地处理长时间序列数据。
  • 数据集生成:该类可以自动生成训练、验证和测试数据集,以便更好地进行模型评估和测试。

这些自动化的功能可以帮助用户更好地处理时间序列数据,提高模型的效果和准确性。

    # todo: refactor:
    # - creating base class with minimal functionality
    # - "outsource" transformations -> use pytorch transformations as default

    # todo: integrate graphs
    # - add option to pass networkx graph to the dataset -> clearly defined
    # - create method to create networkx graph for hierachies -> clearly defined
    # - convert networkx graph to pytorch geometric graph
    # - create sampler to sample from the graph
    # - create option in `to_dataloader` method to use a graph sampler
    #     -> automatically changing collate function which returns graphs
    #     -> should incorporate entire dataset but be compatible with current approach
    # - integrate hierachical loss somehow into loss metrics

    # how to get there:
    # - add networkx and pytorch_geometric to requirements BUT as extras
    #     -> do we also need torch_sparse, etc.? -> can we avoid this? probably not
    # - networkx graph: define what makes sense from user perspective
    # - define conversion into pytorch geometric graph? is this a two-step process of
    #     - encoding networkx graph and converting it into "unfilled" pytorch geometric graph
    #     - then creating full graph in collate function on the fly?
    #     - or is data already stored in pytorch geometric graph and we only cut through it?
    #     - dataformat would change? Is is all timeseries data? + mask when valid?
    #     - then making cuts through the graph in sampling?
    #     - would it be best in this case to re-think the timeseries class and design it as series of transformations?
    #     - what is the new master data? very off current state or very similar?
    #     - current approach is storing data in long format which is memory efficient and using the index object to
    #       make sense of it when accessing. graphs would require wide format?
    # - do NOT overengineer, i.e. support only usecase of single static graph, but only subset might be relevant
    #     -> however, should think what happens if we want a dynamic graph. would this completely change the
    #        data format?

    # decisions:
    # - stay with long format and create graph on the fly even if hampering efficiency and performance
    # - go with pytorch_geometric approach for future proofing
    # - directly convert networkx into pytorch_geometric graph
    # - sampling: support only time-synchronized.
    #     - sample randomly an instance from index as now.
    #     - then get additional samples as per graph (that has been created) and available data
    #     - then collate into graph object

这个注释是一个程序员留下来的任务列表和思考记录,主要是为了让自己或其他开发人员在未来维护和改进这段代码时能够更加明确和清晰地理解代码的逻辑和设计。以下是每个部分的详细解释:

  • “todo: refactor”:需要对代码进行重构,即改进代码的设计、结构和实现方式,以提高代码的可读性、可维护性和可扩展性。
  • “outsource transformations”:把数据变换的功能抽象出来,使用 PyTorch 内置的数据变换函数作为默认选项。
  • “integrate graphs”:将图结构集成到数据集中,以便在训练神经网络时可以使用图结构进行采样和损失计算。
  • “add option to pass networkx graph to the dataset”:增加选项,允许用户将 NetworkX 图传递给数据集,这样可以明确地定义采样和损失计算的图结构。
  • “create method to create networkx graph for hierarchies”:创建方法,用于为分层数据集创建 NetworkX 图。
  • “convert networkx graph to pytorch geometric graph”:将 NetworkX 图转换为 PyTorch Geometric 图,以便在训练神经网络时使用 PyTorch Geometric 库。
  • “create sampler to sample from the graph”:创建采样器,用于从图结构中采样数据。
  • “create option in to_dataloader method to use a graph sampler”:在 to_dataloader 方法中创建选项,以便使用图采样器进行数据加载。
  • “incorporate entire dataset but be compatible with current approach”:需要确保新的图结构方法与现有的数据集兼容,并可以应用于整个数据集。
  • “integrate hierarchical loss somehow into loss metrics”:将分层损失结构集成到损失度量中。
  • “add networkx and pytorch_geometric to requirements BUT as extras”:将 NetworkX 和 PyTorch Geometric 库添加到项目的依赖项中,但作为可选的扩展库,而不是必需的。
  • “define what makes sense from user perspective”:定义从用户角度出发的图结构设计要求。
  • “define conversion into pytorch geometric graph”:定义将 NetworkX 图转换为 PyTorch Geometric 图的过程。
  • “re-think the timeseries class and design it as series of transformations”:重新思考时间序列类的设计,并将其设计为一系列变换操作。
  • “do NOT overengineer”:不要过度设计,即不要支持过于复杂的用例,保持代码的简单和易用性。
  • “stay with long format and create graph on the fly”:保持数据的长格式,动态生成图结构。
  • “go with pytorch_geometric approach for future proofing”:采用 PyTorch Geometric 库的设计方式,以保证代码的可维护性和扩展性。
    def __init__(
        self,
        data: pd.DataFrame,
        time_idx: str,
        target: Union[str, List[str]],
        group_ids: List[str],
        weight: Union[str, None] = None,
        max_encoder_length: int = 30,
        min_encoder_length: int = None,
        min_prediction_idx: int = None,
        min_prediction_length: int = None,
        max_prediction_length: int = 1,
        static_categoricals: List[str] = [],
        static_reals: List[str] = [],
        time_varying_known_categoricals: List[str] = [],
        time_varying_known_reals: List[str] = [],
        time_varying_unknown_categoricals: List[str] = [],
        time_varying_unknown_reals: List[str] = [],
        variable_groups: Dict[str, List[int]] = {},
        constant_fill_strategy: Dict[str, Union[str, float, int, bool]] = {},
        allow_missing_timesteps: bool = False,
        lags: Dict[str, List[int]] = {},
        add_relative_time_idx: bool = False,
        add_target_scales: bool = False,
        add_encoder_length: Union[bool, str] = "auto",
        target_normalizer: Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER]] = "auto",
        categorical_encoders: Dict[str, NaNLabelEncoder] = {},
        scalers: Dict[str, Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer]] = {},
        randomize_length: Union[None, Tuple[float, float], bool] = False,
        predict_mode: bool = False,
    ):
  • data: 包含时间序列的DataFrame或numpy数组。
  • group_ids: 每个时间序列的ID,用于区分时间序列。
  • target: 要预测的目标变量的名称或索引。
  • static_categoricals: 静态分类特征的名称或索引。
  • static_reals: 静态连续特征的名称或索引。
  • time_varying_known_categoricals: 已知的时间变化分类特征的名称或索引。
  • time_varying_known_reals: 已知的时间变化连续特征的名称或索引。
  • time_varying_unknown_categoricals: 未知的时间变化分类特征的名称或索引。
  • time_varying_unknown_reals: 未知的时间变化连续特征的名称或索引。
  • max_encoder_length: 编码器最大长度。
  • max_prediction_length: 预测器最大长度。
  • train_sampler: 用于采样训练数据的采样器。
  • val_sampler: 用于采样验证数据的采样器。
  • batch_size: 批量大小。
  • num_workers: 加载数据的进程数量。
  • scalers: 包含用于缩放数据的scikit-learn scalers的字典。
  • randomize_length: 用于随机化长度的参数。
  • predict_mode: 是否仅迭代每个时间序列一次,即仅使用最后一批提供的样本作为预测样本。

这是一个 Python 类的构造函数,用于构造一个时序数据集。该数据集用于训练时序模型。下面是这个构造函数的参数解释:

  • data: pd.DataFrame。存储时间序列数据的 DataFrame。每行数据都可以由时间索引(time_idx)和 group_ids 确定。
  • time_idx: str。表示时间的列名。该列用于确定样本的时间序列。
  • target: Union[str, List[str]]。目标列或目标列的列表,可以是分类变量或连续变量。
  • group_ids: List[str]。表示时间序列的列名的列表。这意味着 group_ids 与 time_idx 一起确定样本。如果只有一个时间序列,则将其设置为恒定的列名即可。
  • weight: Union[str, None]。权重的列名。默认为 None。
  • max_encoder_length: int。最大编码长度。这是时间序列数据集使用的最大历史长度。
  • min_encoder_length: int。允许的最小编码长度。默认为 max_encoder_length。
  • min_prediction_idx: int。从哪个时间索引开始进行预测。该参数可以用于创建验证或测试集。
  • max_prediction_length: int。最大预测/解码长度(不要选择太短的长度,因为它可能会导致难以收敛)。
  • min_prediction_length: int。最小预测/解码长度。默认为 max_prediction_length。
  • static_categoricals: List[str]。静态分类变量的列表,这些变量随时间不变,条目也可以是列表,然后将它们一起编码(例如,产品类别很有用)。
  • static_reals: List[str]。不随时间变化的连续变量的列表。
  • time_varying_known_categoricals: List[str]。随时间变化但未来已知的分类变量的列表,条目也可以是列表,然后将它们一起编码(例如,特殊日期或促销类别很有用)。
  • time_varying_known_reals: List[str]。随时间变化但未来已知的连续变量的列表(例如,产品的价格,但不是产品的需求)。
  • time_varying_unknown_categoricals: List[str]。随时间变化且未来未知的分类变量的列表,条目也可以是列表,然后将它们一起编码。
  • time_varying_unknown_reals: List[str]。随时间变化且未来未知的连续变量的列表。
  • variable_groups: Dict[str, List[int]]。将变量分组的字典,其中键是组名,值是变量的索引列表。
  • constant_fill_strategy: Dict[str, Union[str, float, int, bool]]。常数填充策略的字典,其中键是列名,值是填充的常数或字符串,可以是 “ffill” 或 “bfill”。
  • allow_missing_timesteps (bool):是否允许数据中存在缺失时间步,并在数据集中自动填充。缺失时间步是指时间序列中存在间隔,例如某个时间序列仅包含时间步1、2、4、5,那么时间步3会在生成数据集时自动生成。但是,此参数不处理缺失值(NA值)。在将数据帧传递给TimeSeriesDataSet之前,应先填充NA值。
  • lags (Dict[str, List[int]]):用于定义变量的滞后时间步的字典。滞后时间步可用于向模型指示数据的季节性。如果您知道数据的季节性,则至少应添加目标变量及其相应的滞后时间步以提高性能。滞后时间步不能大于最短时间序列,并且所有时间序列都将被最大滞后时间步值截断,以防止NA值。滞后变量必须出现在时间变化的变量中。如果只需要滞后值而不需要当前值,则可以在输入数据中手动滞后。
  • add_relative_time_idx (bool):是否将相对时间索引作为特征添加到数据集中。对于每个采样序列,索引将从-encoder_length到prediction_length范围内变化。
  • add_target_scales (bool):如果将目标的中心和比例作为特征添加到静态实值特征中。即将非标准化时间序列的中心和比例作为特征添加到数据集中。
  • add_encoder_length (bool):是否将编码器长度添加到静态实值变量列表中。默认情况下为“auto”,即当“min_encoder_length != max_encoder_length”时为“True”。
  • target_normalizer (Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer, str, list, tuple]):用于对目标进行归一化的转换器。可从TorchNormalizer、GroupNormalizer、NaNLabelEncoder、EncoderNormalizer或None中进行选择。对于多个目标,请使用MultiNormalizer。默认情况下,将自动选择适当的归一化器。
  • categorical_encoders (Dict[str, NaNLabelEncoder]):scikit-learn标签变换器的字典。如果未来存在未观察到的类别/冷启动问题,则可以使用NaNLabelEncoder并设置“add_nan=True”。默认情况下为scikit-learn的“LabelEncoder()”。预先拟合的编码器将不再进行拟合。
  • scalers的字典变量,包含多种不同的标准化方法,例如scikit-learn库中的StandardScaler、RobustScaler,以及pytorch_forecasting库中的EncoderNormalizer和GroupNormalizer等等。默认情况下,使用scikit-learn库中的StandardScaler作为标准化方法。如果要使用其他标准化方法,可以在字典中指定相应的方法名称。此外,还可以选择不使用标准化方法,即使用None,或者使用具有center=0和scale=1参数的标准化方法(method=“identity”)。对于预先拟合好的编码器(encoder)序列,除了EncoderNormalizer需要在每个编码器序列上进行拟合外,其他标准化方法都不需要再次拟合。
  • randomize_length的变量,用于控制是否随机对序列长度进行采样,以及采样的方式。最后,还有一个名为“predict_mode”的布尔变量,用于控制模型是否只迭代一次序列,即只使用每个时间序列的最后几个样本进行预测。
        super().__init__()
        self.max_encoder_length = max_encoder_length
        assert isinstance(self.max_encoder_length, int), "max encoder length must be integer"
        if min_encoder_length is None:
            min_encoder_length = max_encoder_length
        self.min_encoder_length = min_encoder_length
        assert (
            self.min_encoder_length <= self.max_encoder_length
        ), "max encoder length has to be larger equals min encoder length"
        assert isinstance(self.min_encoder_length, int), "min encoder length must be integer"
        self.max_prediction_length = max_prediction_length
        assert isinstance(self.max_prediction_length, int), "max prediction length must be integer"
        if min_prediction_length is None:
            min_prediction_length = max_prediction_length
        self.min_prediction_length = min_prediction_length
        assert (
            self.min_prediction_length <= self.max_prediction_length
        ), "max prediction length has to be larger equals min prediction length"
        assert self.min_prediction_length > 0, "min prediction length must be larger than 0"
        assert isinstance(self.min_prediction_length, int), "min prediction length must be integer"
        assert data[time_idx].dtype.kind == "i", "Timeseries index should be of type integer"
        self.target = target
        self.weight = weight
        self.time_idx = time_idx
        self.group_ids = [] + group_ids
        self.static_categoricals = [] + static_categoricals
        self.static_reals = [] + static_reals
        self.time_varying_known_categoricals = [] + time_varying_known_categoricals
        self.time_varying_known_reals = [] + time_varying_known_reals
        self.time_varying_unknown_categoricals = [] + time_varying_unknown_categoricals
        self.time_varying_unknown_reals = [] + time_varying_unknown_reals
        self.add_relative_time_idx = add_relative_time_idx

        # set automatic defaults
        if isinstance(randomize_length, bool):
            if not randomize_length:
                randomize_length = None
            else:
                randomize_length = (0.2, 0.05)
        self.randomize_length = randomize_length
        if min_prediction_idx is None:
            min_prediction_idx = data[self.time_idx].min()
        self.min_prediction_idx = min_prediction_idx
        self.constant_fill_strategy = {} if len(constant_fill_strategy) == 0 else constant_fill_strategy
        self.predict_mode = predict_mode
        self.allow_missing_timesteps = allow_missing_timesteps
        self.target_normalizer = target_normalizer
        self.categorical_encoders = {} if len(categorical_encoders) == 0 else categorical_encoders
        self.scalers = {} if len(scalers) == 0 else scalers
        self.add_target_scales = add_target_scales
        self.variable_groups = {} if len(variable_groups) == 0 else variable_groups
        self.lags = {} if len(lags) == 0 else lags
  • super().init():调用父类的构造函数来初始化该类的对象。

  • self.max_encoder_length = max_encoder_length:定义最大编码器长度的参数,即用于训练模型的时间序列数据的最大长度。

  • assert isinstance(self.max_encoder_length, int), “max encoder length must be integer”:断言最大编码器长度的值必须是整数类型,否则会抛出一个异常。

  • if min_encoder_length is None: min_encoder_length = max_encoder_length:如果最小编码器长度的参数未指定,则将其默认设置为最大编码器长度。

  • self.min_encoder_length = min_encoder_length:定义最小编码器长度的参数,即用于训练模型的时间序列数据的最小长度。

  • assert (self.min_encoder_length <= self.max_encoder_length), “max encoder length has to be larger equals min encoder length”:断言最大编码器长度必须大于等于最小编码器长度,否则会抛出一个异常。

  • assert isinstance(self.min_encoder_length, int), “min encoder length must be integer”:断言最小编码器长度的值必须是整数类型,否则会抛出一个异常。

  • self.max_prediction_length = max_prediction_length:定义最大预测长度的参数,即模型预测的时间序列数据的最大长度。

  • assert isinstance(self.max_prediction_length, int), “max prediction length must be integer”:断言最大预测长度的值必须是整数类型,否则会抛出一个异常。

  • if min_prediction_length is None: min_prediction_length = max_prediction_length:如果最小预测长度的参数未指定,则将其默认设置为最大预测长度。

  • self.min_prediction_length = min_prediction_length:定义最小预测长度的参数,即模型预测的时间序列数据的最小长度。

  • assert (self.min_prediction_length <= self.max_prediction_length), “max prediction length has to be larger equals min prediction length”:断言最大预测长度必须大于等于最小预测长度,否则会抛出一个异常。

  • assert self.min_prediction_length > 0, “min prediction length must be larger than 0”:断言最小预测长度必须大于0,否则会抛出一个异常。

  • assert isinstance(self.min_prediction_length, int), “min prediction length must be integer”:断言最小预测长度的值必须是整数类型,否则会抛出一个异常。

  • assert data[time_idx].dtype.kind == “i”, “Timeseries index should be of type integer”:断言时间序列数据的索引必须是整数类型,否则会抛出一个异常。

  • self.target = target:定义目标变量的名称。

  • self.weight = weight:定义权重变量的名称。

  • self.time_idx = time_idx:将输入的时间序列数据中用作时间索引的列的索引值存储在self.time_idx中。

  • self.group_ids = [] + group_ids:将输入的组ID列表group_ids复制到self.group_ids中。

  • self.static_categoricals = [] + static_categoricals:将输入的静态分类变量列表static_categoricals复制到self.static_categoricals中。

  • self.static_reals = [] + static_reals:将输入的静态实数变量列表static_reals复制到self.static_reals中。

  • self.time_varying_known_categoricals = [] + time_varying_known_categoricals:将输入的已知时间变化的分类变量列表time_varying_known_categoricals复制到self.time_varying_known_categoricals中。

  • self.time_varying_known_reals = [] + time_varying_known_reals:将输入的已知时间变化的实数变量列表time_varying_known_reals复制到self.time_varying_known_reals中。

  • self.time_varying_unknown_categoricals = [] + time_varying_unknown_categoricals:将输入的未知时间变化的分类变量列表time_varying_unknown_categoricals复制到self.time_varying_unknown_categoricals中。

  • self.time_varying_unknown_reals = [] + time_varying_unknown_reals:将输入的未知时间变化的实数变量列表time_varying_unknown_reals复制到self.time_varying_unknown_reals中。

  • self.add_relative_time_idx = add_relative_time_idx:如果为True,则将时间变量的相对时间索引添加到模型中。

此外,下面这段代码是设置自动默认值:

  • if isinstance(randomize_length, bool)::如果randomize_length是布尔值类型:
  • if not randomize_length::如果randomize_length是False,则将randomize_length设置为None。
    else::否则:
  • randomize_length = (0.2, 0.05):将randomize_length设置为(0.2, 0.05),即随机裁剪长度的默认范围。
  • self.randomize_length: 用于控制是否对每个序列的长度进行随机化处理。如果 randomize_length 为布尔类型,则判断是否需要进行随机化;如果需要随机化,则随机化的长度范围由一个元组 (p, len) 控制,其中 p 是随机化的概率,len 是随机化后的长度变化比例。如果 randomize_length 为 None,则不对序列长度进行随机化处理。
  • self.min_prediction_idx: 最小的预测时间索引,它等于 data 数组中时间索引的最小值。该变量用于对每个序列的预测时间索引进行限制,以避免预测过去的时间。
  • self.constant_fill_strategy: 常量填充策略,它是一个字典类型,用于指定在缺失时间步长上如何填充常量。如果字典为空,则不进行常量填充。
  • self.predict_mode: 预测模式,它控制模型是用来做单步预测还是多步预测。如果为 predict,则进行多步预测;如果为 repeat,则进行单步预测。
  • self.allow_missing_timesteps: 一个布尔值,用于控制是否允许缺失时间步长。如果为 True,则允许缺失时间步长,缺失的时间步将通过 self.constant_fill_strategy 填充;如果为 False,则不允许缺失时间步长,缺失的时间步将被忽略。
  • self.target_normalizer: 目标变量的标准化器,用于标准化目标变量的取值范围。
  • self.categorical_encoders: 分类变量的编码器,是一个字典类型,用于对分类变量进行编码处理。
  • self.scalers: 变量缩放器,是一个字典类型,用于对数值变量进行缩放处理。
  • self.add_target_scales: 一个布尔值,用于控制是否需要对目标变量进行缩放处理。
  • self.variable_groups: 变量组,是一个字典类型,用于将不同的变量分组。
  • self.lags: 时滞值,是一个字典类型,用于将不同的变量的时滞值指定为不同的值。
       # add_encoder_length
        if isinstance(add_encoder_length, str):
            assert (
                add_encoder_length == "auto"
            ), f"Only 'auto' allowed for add_encoder_length but found {add_encoder_length}"
            add_encoder_length = self.min_encoder_length != self.max_encoder_length
        assert isinstance(
            add_encoder_length, bool
        ), f"add_encoder_length should be boolean or 'auto' but found {add_encoder_length}"
        self.add_encoder_length = add_encoder_length

        # target normalizer
        self._set_target_normalizer(data)

        # overwrite values
        self.reset_overwrite_values()

        for target in self.target_names:
            assert (
                target not in self.time_varying_known_reals
            ), f"target {target} should be an unknown continuous variable in the future"

add_encoder_length
如果add_encoder_length是一个字符串,必须是"auto",否则会出现断言错误,只允许使用’auto’自动计算是否要添加encoder_length。当self.min_encoder_length不等于self.max_encoder_length时,add_encoder_length将被设置为True。如果add_encoder_length是一个布尔值或’auto’字符串,那么self.add_encoder_length将被设置为add_encoder_length的值。

target normalizer
设置目标变量的标准化器

overwrite values
重置overwrite_values,即清除之前的已覆盖值。

assert
断言目标变量不应该出现在self.time_varying_known_reals中,而应该是未知的连续变量。

        # add time index relative to prediction position
        if self.add_relative_time_idx or self.add_encoder_length:
            data = data.copy()  # only copies indices (underlying data is NOT copied)
        if self.add_relative_time_idx:
            assert (
                "relative_time_idx" not in data.columns
            ), "relative_time_idx is a protected column and must not be present in data"
            if "relative_time_idx" not in self.time_varying_known_reals and "relative_time_idx" not in self.reals:
                self.time_varying_known_reals.append("relative_time_idx")
            data.loc[:, "relative_time_idx"] = 0.0  # dummy - real value will be set dynamiclly in __getitem__()

add time index relative to prediction position
如果self.add_relative_time_idx或self.add_encoder_length为True,则复制数据,其中只复制索引(不复制基础数据)。如果self.add_relative_time_idx为True,则确保"data"中不存在"relative_time_idx"列,如果不存在,将"relative_time_idx"添加到self.time_varying_known_reals中,并在data中添加"relative_time_idx"列,并将其所有值设置为0.0(在__getitem__()中将动态设置实际值)

        # add decoder length to static real variables
        if self.add_encoder_length:
            assert (
                "encoder_length" not in data.columns
            ), "encoder_length is a protected column and must not be present in data"
            if "encoder_length" not in self.time_varying_known_reals and "encoder_length" not in self.reals:
                self.static_reals.append("encoder_length")
            data.loc[:, "encoder_length"] = 0  # dummy - real value will be set dynamiclly in __getitem__()

        # validate
        self._validate_data(data)
        assert data.index.is_unique, "data index has to be unique"


add decoder length to static real variables
如果self.add_encoder_length为True,则确保"data"中不存在"encoder_length"列,如果不存在,将"encoder_length"添加到self.static_reals中,并在data中添加"encoder_length"列,并将其所有值设置为0(在__getitem__()中将动态设置实际值)

validate
验证数据,确保数据符合要求

        # add lags
        assert self.min_lag > 0, "lags should be positive"
        if len(self.lags) > 0:
            # add variables
            for name in self.lags:
                lagged_names = self._get_lagged_names(name)
                for lagged_name in lagged_names:
                    assert (
                        lagged_name not in data.columns
                    ), f"{lagged_name} is a protected column and must not be present in data"
                # add lags
                if name in self.time_varying_known_reals:
                    for lagged_name in lagged_names:
                        if lagged_name not in self.time_varying_known_reals:
                            self.time_varying_known_reals.append(lagged_name)
                elif name in self.time_varying_known_categoricals:
                    for lagged_name in lagged_names:
                        if lagged_name not in self.time_varying_known_categoricals:
                            self.time_varying_known_categoricals.append(lagged_name)
                elif name in self.time_varying_unknown_reals:
                    for lagged_name, lag in lagged_names.items():
                        if lag < self.max_prediction_length:  # keep in unknown as if lag is too small
                            if lagged_name not in self.time_varying_unknown_reals:
                                self.time_varying_unknown_reals.append(lagged_name)
                        else:
                            if lagged_name not in self.time_varying_known_reals:
                                # switch to known so that lag can be used in decoder directly
                                self.time_varying_known_reals.append(lagged_name)
                elif name in self.time_varying_unknown_categoricals:
                    for lagged_name, lag in lagged_names.items():
                        if lag < self.max_prediction_length:  # keep in unknown as if lag is too small
                            if lagged_name not in self.time_varying_unknown_categoricals:
                                self.time_varying_unknown_categoricals.append(lagged_name)
                        if lagged_name not in self.time_varying_known_categoricals:
                            # switch to known so that lag can be used in decoder directly
                            self.time_varying_known_categoricals.append(lagged_name)
                else:
                    raise KeyError(f"lagged variable {name} is not a known nor unknown time-varying variable")

add lags
断言lags应该是正数。如果lags长度大于0,则将每个lag添加到相应的变量中。如果lagged_name不存在于data.columns中,则将其添加到相应的列表中。如果name在self.time_varying_known_reals中,则将lagged_name添加到self.time_varying_known_reals列表中,如果name在self.time_varying_known_categoricals中,则将lagged_name添加到self.time_varying_known_categoricals列表中,如果name在self.time_varying_unknown_reals中,则将lagged_name添加到self.time_varying_unknown_reals列表中(如果lag小于self.max_prediction_length),否则将其添加到self.time_varying_known_reals列表中(以便可以直接在解码器中使用lag)。如果name在self.time_varying_unknown_categoricals中,则将lagged_name添加到self.time_varying_unknown_categoricals列表中(如果lag小于self.max_prediction_length),否则将其添加到self.time_varying_known_categoricals列表中(以便可以直接在解码器中使用lag)。如果name不在已知或未知的时间变量中,将引发KeyError异常。

        # filter data
        if min_prediction_idx is not None:
            # filtering for min_prediction_idx will be done on subsequence level ensuring
            # minimal decoder index is always >= min_prediction_idx
            data = data[lambda x: x[self.time_idx] >= self.min_prediction_idx - self.max_encoder_length - self.max_lag]
        data = data.sort_values(self.group_ids + [self.time_idx])

        # preprocess data
        data = self._preprocess_data(data)
        for target in self.target_names:
            assert target not in self.scalers, "Target normalizer is separate and not in scalers."

        # create index
        self.index = self._construct_index(data, predict_mode=self.predict_mode)

        # convert to torch tensor for high performance data loading later
        self.data = self._data_to_tensors(data)

这段代码主要用于数据预处理和构建PyTorch Tensor以用于数据加载。解释如下:

  • 如果min_prediction_idx不为None,则对数据进行过滤,以确保最小的解码器索引始终大于等于min_prediction_idx。这个过滤会在子序列级别进行,确保min_prediction_idx是可行的。筛选后的数据将被用于后续的预处理。过滤后的数据通过在DataFrame上使用lambda表达式完成。
  • 对筛选后的数据进行排序,按照group_ids和time_idx进行排序。这个排序是为了使数据能够在后续的操作中进行更有效的分组和处理。
  • 对数据进行预处理,包括标准化和缺失值填充。对于每个目标变量,确保其没有在缩放器中。
  • 构建索引。使用预处理后的数据构建索引,以便后续的批处理和训练中使用。
    将数据转换为PyTorch Tensor,以便后续的高性能数据加载。转换后的数据将存储在self.data中,供后续使用。
 @property
    def dropout_categoricals(self) -> List[str]:
        """
        list of categorical variables that are unknown when making a
        forecast without observed history
        """
        return [name for name, encoder in self.categorical_encoders.items() if encoder.add_nan]

    def _get_lagged_names(self, name: str) -> Dict[str, int]:
        """
        Generate names for lagged variables

        Args:
            name (str): name of variable to lag

        Returns:
            Dict[str, int]: dictionary mapping new variable names to lags
        """
        return {f"{name}_lagged_by_{lag}": lag for lag in self.lags.get(name, [])}

    @property
    @lru_cache(None)
    def lagged_variables(self) -> Dict[str, str]:
        """
        Lagged variables.

        Returns:
            Dict[str, str]: dictionary of variable names corresponding to lagged variables
                mapped to variable that is lagged
        """
        vars = {}
        for name in self.lags:
            vars.update({lag_name: name for lag_name in self._get_lagged_names(name)})
        return vars

    @property
    @lru_cache(None)
    def lagged_targets(self) -> Dict[str, str]:
        """Subset of `lagged_variables` but only includes variables that are lagged targets."""
        vars = {}
        for name in self.lags:
            vars.update({lag_name: name for lag_name in self._get_lagged_names(name) if name in self.target_names})
        return vars

    @property
    @lru_cache(None)
    def min_lag(self) -> int:
        """
        Minimum number of time steps variables are lagged.

        Returns:
            int: minimum lag
        """
        if len(self.lags) == 0:
            return 1e9
        else:
            return min([min(lag) for lag in self.lags.values()])

    @property
    @lru_cache(None)
    def max_lag(self) -> int:
        """
        Maximum number of time steps variables are lagged.

        Returns:
            int: maximum lag
        """
        if len(self.lags) == 0:
            return 0
        else:
            return max([max(lag) for lag in self.lags.values()])

这是一个Python类中的一些属性和方法,这些属性和方法主要用于处理时间序列数据的特征工程。下面对每个属性和方法进行逐一解释:

@property:这是一个Python中的装饰器,用于将方法转换为类属性。

dropout_categoricals:这是一个类属性方法,返回一个列表,其中包含了那些在预测过程中没有观测历史数据的分类变量的名称。具体来说,这个列表中包含了所有满足 encoder.add_nan=True 的分类变量名称。

_get_lagged_names:这是一个类方法,用于生成具有滞后特性的变量名称。具体来说,这个方法接受一个变量名称作为参数,并返回一个字典,其中包含了具有特定滞后量的变量名称及其对应的滞后步数。

lagged_variables:这是一个类属性方法,返回一个字典,其中包含了所有具有滞后特性的变量名称及其对应的原始变量名称。具体来说,这个字典中包含了所有变量名称的滞后版本,这些变量由方法 _get_lagged_names 生成。

lagged_targets:这是一个类属性方法,返回一个字典,其中包含了所有具有滞后特性的目标变量名称及其对应的原始变量名称。具体来说,这个字典中只包含了满足 name in self.target_names 的目标变量名称的滞后版本,这些变量同样由方法 _get_lagged_names 生成。

min_lag:这是一个类属性方法,返回所有具有滞后特性的变量中最小的滞后步数。具体来说,这个方法返回满足 min(lag) 的最小的滞后步数。

max_lag:这是一个类属性方法,返回所有具有滞后特性的变量中最大的滞后步数。具体来说,这个方法返回满足 max(lag) 的最大的滞后步数。

 def _set_target_normalizer(self, data: pd.DataFrame):
        """
        Determine target normalizer.

        Args:
            data (pd.DataFrame): input data
        """
        if isinstance(self.target_normalizer, str) and self.target_normalizer == "auto":
            normalizers = []
            for target in self.target_names:
                if data[target].dtype.kind != "f":  # category
                    normalizers.append(NaNLabelEncoder())
                    if self.add_target_scales:
                        warnings.warn("Target scales will be only added for continous targets", UserWarning)
                else:
                    data_positive = (data[target] > 0).all()
                    if data_positive:
                        if data[target].skew() > 2.5:
                            transformer = "log"
                        else:
                            transformer = "relu"
                    else:
                        transformer = None
                    if self.max_encoder_length > 20 and self.min_encoder_length > 1:
                        normalizers.append(EncoderNormalizer(transformation=transformer))
                    else:
                        normalizers.append(GroupNormalizer(transformation=transformer))
            if self.multi_target:
                self.target_normalizer = MultiNormalizer(normalizers)
            else:
                self.target_normalizer = normalizers[0]
        elif isinstance(self.target_normalizer, (tuple, list)):
            self.target_normalizer = MultiNormalizer(self.target_normalizer)
        elif self.target_normalizer is None:
            self.target_normalizer = TorchNormalizer(method="identity")
        assert (
            not isinstance(self.target_normalizer, EncoderNormalizer)
            or self.min_encoder_length >= self.target_normalizer.min_length
        ), "EncoderNormalizer is only allowed if min_encoder_length > 1"
        assert isinstance(
            self.target_normalizer, (TorchNormalizer, NaNLabelEncoder)
        ), f"target_normalizer has to be either None or of class TorchNormalizer but found {self.target_normalizer}"
        assert not self.multi_target or isinstance(self.target_normalizer, MultiNormalizer), (
            "multiple targets / list of targets requires MultiNormalizer as target_normalizer "
            f"but found {self.target_normalizer}"

这段代码是用来确定目标变量标准化器(target normalizer)的,目标变量指的是需要被预测的变量。该函数接受一个数据框作为参数,其中包含需要被预测的变量。

首先,函数会检查目标标准化器的类型。如果目标标准化器是字符串并且等于 “auto”,则会根据目标变量的数据类型自动选择标准化器。如果目标变量是分类变量,则使用 NaNLabelEncoder 标准化器,否则判断目标变量的分布情况,如果数据都是正数,且数据偏度大于2.5,则使用 log 变换,否则使用 ReLU 变换。如果需要对时间序列数据进行标准化,则会考虑 EncoderNormalizer 和 GroupNormalizer 两种标准化器,具体选择哪一种标准化器取决于数据集的长度和 EncoderNormalizer 的最小长度参数 min_length。如果是多目标预测,则使用 MultiNormalizer。

如果目标标准化器是元组或列表,则使用 MultiNormalizer。如果目标标准化器为 None,则使用 TorchNormalizer。最后,函数对目标标准化器的类型进行检查和断言,确保标准化器的正确性。

@property
    @lru_cache(None)
    def _group_ids_mapping(self) -> Dict[str, str]:
        """
        Mapping of group id names to group ids used to identify series in dataset -
        group ids can also be used for target normalizer.
        The former can change from training to validation and test dataset while the later must not.
        """
        return {name: f"__group_id__{name}" for name in self.group_ids}

    @property
    @lru_cache(None)
    def _group_ids(self) -> List[str]:
        """
        Group ids used to identify series in dataset.

        See :py:meth:`~TimeSeriesDataSet._group_ids_mapping` for details.
        """
        return list(self._group_ids_mapping.values())

    def _validate_data(self, data: pd.DataFrame):
        """
        Validate that data will not cause hick-ups later on.
        """
        # check for numeric categoricals which can cause hick-ups in logging in tensorboard
        category_columns = data.head(1).select_dtypes("category").columns
        object_columns = data.head(1).select_dtypes(object).columns
        for name in self.flat_categoricals:
            if name not in data.columns:
                raise KeyError(f"variable {name} specified but not found in data")
            if not (
                name in object_columns
                or (name in category_columns and data[name].cat.categories.dtype.kind not in "bifc")
            ):
                raise ValueError(
                    f"Data type of category {name} was found to be numeric - use a string type / categorified string"
                )
        # check for "." in column names
        columns_with_dot = data.columns[data.columns.str.contains(r"\.")]
        if len(columns_with_dot) > 0:
            raise ValueError(
                f"column names must not contain '.' characters. Names {columns_with_dot.tolist()} are invalid"
            )

    def save(self, fname: str) -> None:
        """
        Save dataset to disk

        Args:
            fname (str): filename to save to
        """
        torch.save(self, fname)

    @classmethod
    def load(cls, fname: str):
        """
        Load dataset from disk

        Args:
            fname (str): filename to load from

        Returns:
            TimeSeriesDataSet
        """
        obj = torch.load(fname)
        assert isinstance(obj, cls), f"Loaded file is not of class {cls}"
        return obj

这段代码定义了一个类 TimeSeriesDataSet,包括以下几个方法:

@property和@lru_cache(None)装饰器修饰的_group_ids_mapping和_group_ids方法,用于获取group_ids和它们的映射关系,以便于在数据集中识别系列。
_validate_data方法,用于验证数据集是否符合一些要求,例如类别列是否为字符串类型。
save和load方法,用于将数据集保存和加载到磁盘中,这里使用了PyTorch的torch.save和torch.load方法。
此外,还有一些注释,解释了这些方法的功能,例如_group_ids_mapping方法返回一个映射关系字典,用于将group id名称映射到数据集中用于识别系列的group id。_validate_data方法用于验证数据是否符合一些要求,例如类别列必须为字符串类型,否则会导致记录TensorBoard时出现问题。save和load方法用于将数据集保存到磁盘上,方便之后的训练和推理使用。

    def _preprocess_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Scale continuous variables, encode categories and set aside target and weight.

        Args:
            data (pd.DataFrame): original data

        Returns:
            pd.DataFrame: pre-processed dataframe
        """
        # add lags to data
        for name in self.lags:
            # todo: add support for variable groups
            assert (
                name not in self.variable_groups
            ), f"lagged variables that are in {self.variable_groups} are not supported yet"
            for lagged_name, lag in self._get_lagged_names(name).items():
                data[lagged_name] = data.groupby(self.group_ids, observed=True)[name].shift(lag)

        # encode group ids - this encoding
        for name, group_name in self._group_ids_mapping.items():
            # use existing encoder - but a copy of it not too loose current encodings
            encoder = deepcopy(self.categorical_encoders.get(group_name, NaNLabelEncoder()))
            self.categorical_encoders[group_name] = encoder.fit(data[name].to_numpy().reshape(-1), overwrite=False)
            data[group_name] = self.transform_values(name, data[name], inverse=False, group_id=True)

        # encode categoricals first to ensure that group normalizer for relies on encoded categories
        if isinstance(
            self.target_normalizer, (GroupNormalizer, MultiNormalizer)
        ):  # if we use a group normalizer, group_ids must be encoded as well
            group_ids_to_encode = self.group_ids
        else:
            group_ids_to_encode = []
        for name in dict.fromkeys(group_ids_to_encode + self.categoricals):
            if name in self.lagged_variables:
                continue  # do not encode here but only in transform
            if name in self.variable_groups:  # fit groups
                columns = self.variable_groups[name]
                if name not in self.categorical_encoders:
                    self.categorical_encoders[name] = NaNLabelEncoder().fit(data[columns].to_numpy().reshape(-1))
                elif self.categorical_encoders[name] is not None:
                    try:
                        check_is_fitted(self.categorical_encoders[name])
                    except NotFittedError:
                        self.categorical_encoders[name] = self.categorical_encoders[name].fit(
                            data[columns].to_numpy().reshape(-1)
                        )
            else:
                if name not in self.categorical_encoders:
                    self.categorical_encoders[name] = NaNLabelEncoder().fit(data[name])
                elif self.categorical_encoders[name] is not None and name not in self.target_names:
                    try:
                        check_is_fitted(self.categorical_encoders[name])
                    except NotFittedError:
                        self.categorical_encoders[name] = self.categorical_encoders[name].fit(data[name])

        # encode them
        for name in dict.fromkeys(group_ids_to_encode + self.flat_categoricals):
            # targets and its lagged versions are handled separetely
            if name not in self.target_names and name not in self.lagged_targets:
                data[name] = self.transform_values(
                    name, data[name], inverse=False, ignore_na=name in self.lagged_variables
                )

        # save special variables
        assert "__time_idx__" not in data.columns, "__time_idx__ is a protected column and must not be present in data"
        data["__time_idx__"] = data[self.time_idx]  # save unscaled
        for target in self.target_names:
            assert (
                f"__target__{target}" not in data.columns
            ), f"__target__{target} is a protected column and must not be present in data"
            data[f"__target__{target}"] = data[target]
        if self.weight is not None:
            data["__weight__"] = data[self.weight]

        # train target normalizer
        if self.target_normalizer is not None:

            # fit target normalizer
            try:
                check_is_fitted(self.target_normalizer)
            except NotFittedError:
                if isinstance(self.target_normalizer, EncoderNormalizer):
                    self.target_normalizer.fit(data[self.target])
                elif isinstance(self.target_normalizer, (GroupNormalizer, MultiNormalizer)):
                    self.target_normalizer.fit(data[self.target], data)
                else:
                    self.target_normalizer.fit(data[self.target])

            # transform target
            if isinstance(self.target_normalizer, EncoderNormalizer):
                # we approximate the scales and target transformation by assuming one
                # transformation over the entire time range but by each group
                common_init_args = [
                    name
                    for name in inspect.signature(GroupNormalizer.__init__).parameters.keys()
                    if name in inspect.signature(EncoderNormalizer.__init__).parameters.keys()
                    and name not in ["data", "self"]
                ]
                copy_kwargs = {name: getattr(self.target_normalizer, name) for name in common_init_args}
                normalizer = GroupNormalizer(groups=self.group_ids, **copy_kwargs)
                data[self.target], scales = normalizer.fit_transform(data[self.target], data, return_norm=True)

            elif isinstance(self.target_normalizer, GroupNormalizer):
                data[self.target], scales = self.target_normalizer.transform(data[self.target], data, return_norm=True)

            elif isinstance(self.target_normalizer, MultiNormalizer):
                transformed, scales = self.target_normalizer.transform(data[self.target], data, return_norm=True)

                for idx, target in enumerate(self.target_names):
                    data[target] = transformed[idx]

                    if isinstance(self.target_normalizer[idx], NaNLabelEncoder):
                        # overwrite target because it requires encoding (continuous targets should not be normalized)
                        data[f"__target__{target}"] = data[target]

            elif isinstance(self.target_normalizer, NaNLabelEncoder):
                data[self.target] = self.target_normalizer.transform(data[self.target])
                # overwrite target because it requires encoding (continuous targets should not be normalized)
                data[f"__target__{self.target}"] = data[self.target]
                scales = None

            else:
                data[self.target], scales = self.target_normalizer.transform(data[self.target], return_norm=True)

            # add target scales
            if self.add_target_scales:
                if not isinstance(self.target_normalizer, MultiNormalizer):
                    scales = [scales]
                for target_idx, target in enumerate(self.target_names):
                    if not isinstance(self.target_normalizers[target_idx], NaNLabelEncoder):
                        for scale_idx, name in enumerate(["center", "scale"]):
                            feature_name = f"{target}_{name}"
                            assert (
                                feature_name not in data.columns
                            ), f"{feature_name} is a protected column and must not be present in data"
                            data[feature_name] = scales[target_idx][:, scale_idx].squeeze()
                            if feature_name not in self.reals:
                                self.static_reals.append(feature_name)

        # rescale continuous variables apart from target
        for name in self.reals:
            if name in self.target_names or name in self.lagged_variables:
                # lagged variables are only transformed - not fitted
                continue
            elif name not in self.scalers:
                self.scalers[name] = StandardScaler().fit(data[[name]])
            elif self.scalers[name] is not None:
                try:
                    check_is_fitted(self.scalers[name])
                except NotFittedError:
                    if isinstance(self.scalers[name], GroupNormalizer):
                        self.scalers[name] = self.scalers[name].fit(data[name], data)
                    else:
                        self.scalers[name] = self.scalers[name].fit(data[[name]])

        # encode after fitting
        for name in self.reals:
            # targets are handled separately
            transformer = self.get_transformer(name)
            if (
                name not in self.target_names
                and transformer is not None
                and not isinstance(transformer, EncoderNormalizer)
            ):
                data[name] = self.transform_values(name, data[name], data=data, inverse=False)

        # encode lagged categorical targets
        for name in self.lagged_targets:
            # normalizer only now available
            if name in self.flat_categoricals:
                data[name] = self.transform_values(name, data[name], inverse=False, ignore_na=True)

        # encode constant values
        self.encoded_constant_fill_strategy = {}
        for name, value in self.constant_fill_strategy.items():
            if name in self.target_names:
                self.encoded_constant_fill_strategy[f"__target__{name}"] = value
            self.encoded_constant_fill_strategy[name] = self.transform_values(
                name, np.array([value]), data=data, inverse=False
            )[0]

        # shorten data by maximum of lagged sequences to avoid NA values - shorten only after encoding
        if self.max_lag > 0:
            # negative tail implementation as .groupby().tail(-self.max_lag) is not implemented in pandas
            g = data.groupby(self._group_ids, observed=True)
            data = g._selected_obj[g.cumcount() >= self.max_lag]
        return data

这个函数被用于 Prophet 模型中对输入数据进行预处理。它包含以下几个步骤:
对连续变量进行缩放
为数据添加滞后项。
对组 ID 进行编码。
对分类特征进行编码。
保存特殊变量(time_idx、__target__name 和 weight)。
训练目标变量的标准化器。
转换目标变量。
让我们逐步详细解释这些步骤:

对连续变量进行缩放,对分类变量进行编码,然后将目标和权重设置到一边。

为数据添加滞后项:对于在 lags 参数中指定的每个变量,该步骤会将变量的滞后版本添加到数据中。滞后版本的数量由 lags 参数的值决定。此步骤使用 _get_lagged_names 方法为变量的滞后版本创建新的列名。将滞后值添加到数据中。遍历每个滞后变量,在数据中添加新的列,新列的值是原变量经过滞后处理后的结果。

对组 ID 进行编码:对于每个组 ID 的名称和相应的新名称,用现有的编码器进行编码。该步骤对组 ID 进行编码,组 ID 在 group_ids 参数中指定。编码方式是将每个组 ID 映射到一个整数。该步骤使用 pandas 的 factorize 函数进行编码。

对分类特征进行编码:为了保证后续的组归一化器依赖于已编码的分类变量,首先对分类变量进行编码,然后在编码后的数据上进行组归一化。该步骤对分类特征进行编码,分类特征在 categorical_features 参数中指定。编码方式是使用 one-hot 编码,即将每个分类特征拆分为多个二元特征。该步骤使用 pandas 的 get_dummies 函数进行编码。

保存特殊变量:该步骤将特殊变量(time_idx、__target__name 和 weight)保存到一个字典中,以便在之后的步骤中使用。保证 “time_idx” 和 “target” 前缀的列名不会出现在数据中。保存目标和权重,如果存在。

对目标变量进行归一化处理。如果使用了目标变量的归一化器,首先进行拟合操作,然后在训练集上进行归一化。如果使用了 GroupNormalizer,将在每个组中对目标进行归一化。如果使用了 EncoderNormalizer,则对目标进行编码后再进行归一化。

训练目标变量的标准化器:该步骤使用目标变量(__target__name)的值训练一个标准化器,以将目标变量缩放到指定的范围内。标准化器使用 sklearn 的 StandardScaler 类实现。

转换目标变量:该步骤使用训练好的标准化器将目标变量缩放到指定的范围内。这一步骤也会将目标变量的名称从 __target__name 更改为 y。

这个函数的输入是一个 Pandas DataFrame 对象,输出也是一个 Pandas DataFrame 对象。在预处理数据之后,输出的 DataFrame 可以用于模型训练。

def get_transformer(self, name: str, group_id: bool = False):
    """
    Get transformer for variable.

    Args:
        name (str): variable name
        group_id (bool, optional): If the passed name refers to a group id (different encoders are used for these).
            Defaults to False.

    Returns:
        transformer
    """
    if group_id:
        name = self._group_ids_mapping[name]
    elif name in self.lagged_variables:  # recover transformer fitted on non-lagged variable
        name = self.lagged_variables[name]

    if name in self.flat_categoricals + self.group_ids + self._group_ids:
        name = self.variable_to_group_mapping.get(name, name)  # map name to encoder

        # take target normalizer if required
        if name in self.target_names:
            transformer = self.target_normalizers[self.target_names.index(name)]
        else:
            transformer = self.categorical_encoders.get(name, None)
        return transformer

    elif name in self.reals:
        # take target normalizer if required
        if name in self.target_names:
            transformer = self.target_normalizers[self.target_names.index(name)]
        else:
            transformer = self.scalers.get(name, None)
        return transformer
    else:
        return None

这段代码是 Prophet 模型中的一个函数,用于获取变量的转换器(transformer)。下面逐行解释:

  • def get_transformer(self, name: str, group_id: bool = False)::定义函数 get_transformer,有两个输入参数 name 和 group_id,group_id 默认为 False。
  • if group_id::如果 group_id 为 True,则需要将 name 映射到相应的编码器(encoder)上。
  • name = self._group_ids_mapping[name]:将 name 映射到相应的编码器上。
  • elif name in self.lagged_variables::如果 name 是滞后变量,则需要恢复到未滞后变量上已经拟合的转换器。
    name = self.lagged_variables[name]:恢复到未滞后变量上已经拟合的转换器。
  • if name in self.flat_categoricals + self.group_ids + self._group_ids::如果 name 是平面分类变量、分组变量或者组别变量之一,则需要将 name 映射到相应的编码器上。
    name = self.variable_to_group_mapping.get(name, name):将 name 映射到相应的编码器上。
  • if name in self.target_names::如果 name 是目标变量之一,则需要将其映射到目标归一化器上。
    transformer = self.target_normalizers[self.target_names.index(name)]:将 name 映射到目标归一化器上。
  • else::如果不是目标变量之一,则需要将其映射到相应的编码器上。
  • transformer = self.categorical_encoders.get(name, None):将 name 映射到相应的编码器上。
  • return transformer:返回变量的转换器。
  • elif name in self.reals::如果变量名 name 在 self.reals 列表中,即为连续型变量,则获取该变量的归一化器(transformer)。如果该变量是目标变量之一(在 self.target_names 列表中),则获取目标变量归一化器(self.target_normalizers);否则获取连续型变量的归一化器(self.scalers)。如果 name 不是任何变量类型,则返回 None。
    def transform_values(
        self,
        name: str,
        values: Union[pd.Series, torch.Tensor, np.ndarray],
        data: pd.DataFrame = None,
        inverse=False,
        group_id: bool = False,
        **kwargs,
    ) -> np.ndarray:
        """
        Scale and encode values.

        Args:
            name (str): name of variable
            values (Union[pd.Series, torch.Tensor, np.ndarray]): values to encode/scale
            data (pd.DataFrame, optional): extra data used for scaling (e.g. dataframe with groups columns).
                Defaults to None.
            inverse (bool, optional): if to conduct inverse transformation. Defaults to False.
            group_id (bool, optional): If the passed name refers to a group id (different encoders are used for these).
                Defaults to False.
            **kwargs: additional arguments for transform/inverse_transform method

        Returns:
            np.ndarray: (de/en)coded/(de)scaled values
        """
        transformer = self.get_transformer(name, group_id=group_id)
        if transformer is None:
            return values
        if inverse:
            transform = transformer.inverse_transform
        else:
            transform = transformer.transform

        if group_id:
            name = self._group_ids_mapping[name]
        # remaining categories
        if name in self.flat_categoricals + self.group_ids + self._group_ids:
            return transform(values, **kwargs)

        # reals
        elif name in self.reals:
            if isinstance(transformer, GroupNormalizer):
                return transform(values, data, **kwargs)
            elif isinstance(transformer, EncoderNormalizer):
                return transform(values, **kwargs)
            else:
                if isinstance(values, pd.Series):
                    values = values.to_frame()
                    return np.asarray(transform(values, **kwargs)).reshape(-1)
                else:
                    values = values.reshape(-1, 1)
                    return transform(values, **kwargs).reshape(-1)
        else:
            return values

这段代码用于对输入的数据进行归一化和编码操作。

函数名为 transform_values,包含六个输入参数:name 表示数据的变量名,values 表示数据的值,data 表示额外用于归一化的数据,inverse 表示是否对数据进行反转换,group_id 表示是否对分组数据进行编码,**kwargs 表示可变数量的参数列表。

首先,通过调用 get_transformer 函数获取数据的归一化器或编码器。如果获取的结果为 None,则表示该变量不需要进行归一化和编码操作,直接返回原始数据。否则,根据 inverse 参数判断是否需要进行反转换,选择调用归一化器或编码器的 transform 或 inverse_transform 函数。如果需要对分组数据进行编码,需要将变量名转换为相应的编码器,并将输入数据传入 transform 函数中进行编码。对于其他数据类型,则需要先将输入数据按照要求的格式进行转换,然后调用相应的函数进行归一化或编码,并最终返回结果。

总之,这段代码实现了模型中对数据进行归一化和编码操作的核心逻辑。

 def _data_to_tensors(self, data: pd.DataFrame) -> Dict[str, torch.Tensor]:
        """
        Convert data to tensors for faster access with :py:meth:`~__getitem__`.

        Args:
            data (pd.DataFrame): preprocessed data

        Returns:
            Dict[str, torch.Tensor]: dictionary of tensors for continous, categorical data, groups, target and
                time index
        """

        index = check_for_nonfinite(
            torch.tensor(data[self._group_ids].to_numpy(np.int64), dtype=torch.int64), self.group_ids
        )
        time = check_for_nonfinite(
            torch.tensor(data["__time_idx__"].to_numpy(np.int64), dtype=torch.int64), self.time_idx
        )

        # categorical covariates
        categorical = check_for_nonfinite(
            torch.tensor(data[self.flat_categoricals].to_numpy(np.int64), dtype=torch.int64), self.flat_categoricals
        )

        # get weight
        if self.weight is not None:
            weight = check_for_nonfinite(
                torch.tensor(
                    data["__weight__"].to_numpy(dtype=np.float64),
                    dtype=torch.float,
                ),
                self.weight,
            )
        else:
            weight = None

        # get target
        if isinstance(self.target_normalizer, NaNLabelEncoder):
            target = [
                check_for_nonfinite(
                    torch.tensor(data[f"__target__{self.target}"].to_numpy(dtype=np.int64), dtype=torch.long),
                    self.target,
                )
            ]
        else:
            if not isinstance(self.target, str):  # multi-target
                target = [
                    check_for_nonfinite(
                        torch.tensor(
                            data[f"__target__{name}"].to_numpy(
                                dtype=[np.float64, np.int64][data[name].dtype.kind in "bi"]
                            ),
                            dtype=[torch.float, torch.long][data[name].dtype.kind in "bi"],
                        ),
                        name,
                    )
                    for name in self.target_names
                ]
            else:
                target = [
                    check_for_nonfinite(
                        torch.tensor(data[f"__target__{self.target}"].to_numpy(dtype=np.float64), dtype=torch.float),
                        self.target,
                    )
                ]

        # continuous covariates
        continuous = check_for_nonfinite(
            torch.tensor(data[self.reals].to_numpy(dtype=np.float64), dtype=torch.float), self.reals
        )

        tensors = dict(
            reals=continuous, categoricals=categorical, groups=index, target=target, weight=weight, time=time
        )

        return tensors

这段代码实现了将经过预处理的数据转换为张量(tensors),以便在模型训练和预测时更快地访问数据。它的输入是一个预处理后的 Pandas DataFrame,包含了所有模型需要的输入变量,如分类变量、连续变量、组 ID、时间戳、目标变量和样本权重。函数将数据转换为对应的 PyTorch 张量(tensor),并将其存储在一个字典中返回,每个键值对应一种类型的输入变量。

具体来说,函数首先将组 ID、时间戳和分类变量转换为整数类型(int64)的张量,并对其进行 NaN 值检查,以确保数据的正确性。接着,如果有样本权重,将其转换为浮点类型(float)的张量。如果有目标变量,则将其转换为相应的张量,根据是否是多目标问题,可能需要多次转换。最后,将连续变量转换为浮点类型(float)的张量。所有这些张量被存储在一个字典中,每个键代表一种类型的输入变量,以便在训练和预测过程中更方便地使用。

@property
    def categoricals(self) -> List[str]:
        """
        Categorical variables as used for modelling.

        Returns:
            List[str]: list of variables
        """
        return self.static_categoricals + self.time_varying_known_categoricals + self.time_varying_unknown_categoricals

    @property
    def flat_categoricals(self) -> List[str]:
        """
        Categorical variables as defined in input data.

        Returns:
            List[str]: list of variables
        """
        categories = []
        for name in self.categoricals:
            if name in self.variable_groups:
                categories.extend(self.variable_groups[name])
            else:
                categories.append(name)
        return categories

    @property
    def variable_to_group_mapping(self) -> Dict[str, str]:
        """
        Mapping from categorical variables to variables in input data.

        Returns:
            Dict[str, str]: dictionary mapping from :py:meth:`~categorical` to :py:meth:`~flat_categoricals`.
        """
        groups = {}
        for group_name, sublist in self.variable_groups.items():
            groups.update({name: group_name for name in sublist})
        return groups

    @property
    def reals(self) -> List[str]:
        """
        Continous variables as used for modelling.

        Returns:
            List[str]: list of variables
        """
        return self.static_reals + self.time_varying_known_reals + self.time_varying_unknown_reals

    @property
    @lru_cache(None)
    def target_names(self) -> List[str]:
        """
        List of targets.

        Returns:
            List[str]: list of targets
        """
        if self.multi_target:
            return self.target
        else:
            return [self.target]

    @property
    def multi_target(self) -> bool:
        """
        If dataset encodes one or multiple targets.

        Returns:
            bool: true if multiple targets
        """
        return isinstance(self.target, (list, tuple))

    @property
    def target_normalizers(self) -> List[TorchNormalizer]:
        """
        List of target normalizers aligned with ``target_names``.

        Returns:
            List[TorchNormalizer]: list of target normalizers
        """
        if isinstance(self.target_normalizer, MultiNormalizer):
            target_normalizers = self.target_normalizer.normalizers
        else:
            target_normalizers = [self.target_normalizer]
        return target_normalizers

    def get_parameters(self) -> Dict[str, Any]:
        """
        Get parameters that can be used with :py:meth:`~from_parameters` to create a new dataset with the same scalers.

        Returns:
            Dict[str, Any]: dictionary of parameters
        """
        kwargs = {
            name: getattr(self, name)
            for name in inspect.signature(self.__class__.__init__).parameters.keys()
            if name not in ["data", "self"]
        }
        kwargs["categorical_encoders"] = self.categorical_encoders
        kwargs["scalers"] = self.scalers
        return kwargs

    @classmethod
    def from_dataset(
        cls, dataset, data: pd.DataFrame, stop_randomization: bool = False, predict: bool = False, **update_kwargs
    ):
        """
        Generate dataset with different underlying data but same variable encoders and scalers, etc.

        Calls :py:meth:`~from_parameters` under the hood.

        Args:
            dataset (TimeSeriesDataSet): dataset from which to copy parameters
            data (pd.DataFrame): data from which new dataset will be generated
            stop_randomization (bool, optional): If to stop randomizing encoder and decoder lengths,
                e.g. useful for validation set. Defaults to False.
            predict (bool, optional): If to predict the decoder length on the last entries in the
                time index (i.e. one prediction per group only). Defaults to False.
            **kwargs: keyword arguments overriding parameters in the original dataset

        Returns:
            TimeSeriesDataSet: new dataset
        """
        return cls.from_parameters(
            dataset.get_parameters(), data, stop_randomization=stop_randomization, predict=predict, **update_kwargs
        )

categoricals属性:获取所有用于建模的分类变量的列表,包括静态分类变量、已知时间变化的分类变量和未知时间变化的分类变量。

flat_categoricals属性:获取输入数据中所有分类变量的列表,包括每个分类变量的所有可能值。

variable_to_group_mapping属性:获取从模型分类变量到输入数据分类变量的映射字典。

reals属性:获取所有用于建模的连续变量的列表,包括静态连续变量、已知时间变化的连续变量和未知时间变化的连续变量。

target_names属性:获取所有目标变量的列表,如果数据集只有一个目标,则将其转化为列表。

multi_target属性:检查数据集是否包含多个目标变量。

target_normalizers属性:获取目标变量归一化器的列表,与target_names对应。

get_parameters方法:获取当前数据集的参数,以便使用from_parameters方法创建具有相同缩放器的新数据集。

from_dataset类方法:创建具有不同底层数据但相同变量编码器和缩放器等的新数据集。该方法在内部调用from_parameters方法。可以通过传递stop_randomization和predict参数,以及关键字参数覆盖原始数据集中的参数来自定义新数据集。

  @classmethod
    def from_parameters(
        cls,
        parameters: Dict[str, Any],
        data: pd.DataFrame,
        stop_randomization: bool = None,
        predict: bool = False,
        **update_kwargs,
    ):
        """
        Generate dataset with different underlying data but same variable encoders and scalers, etc.

        Args:
            parameters (Dict[str, Any]): dataset parameters which to use for the new dataset
            data (pd.DataFrame): data from which new dataset will be generated
            stop_randomization (bool, optional): If to stop randomizing encoder and decoder lengths,
                e.g. useful for validation set. Defaults to False.
            predict (bool, optional): If to predict the decoder length on the last entries in the
                time index (i.e. one prediction per group only). Defaults to False.
            **kwargs: keyword arguments overriding parameters

        Returns:
            TimeSeriesDataSet: new dataset
        """
        parameters = deepcopy(parameters)
        if predict:
            if stop_randomization is None:
                stop_randomization = True
            elif not stop_randomization:
                warnings.warn(
                    "If predicting, no randomization should be possible - setting stop_randomization=True", UserWarning
                )
                stop_randomization = True
            parameters["min_prediction_length"] = parameters["max_prediction_length"]
            parameters["predict_mode"] = True
        elif stop_randomization is None:
            stop_randomization = False

        if stop_randomization:
            parameters["randomize_length"] = None
        parameters.update(update_kwargs)

        new = cls(data, **parameters)
        return new

这段代码定义了一个类方法 from_parameters,该方法用于生成一个具有不同潜在数据但具有相同变量编码器和缩放器等特征的数据集。

这个方法有几个参数:

parameters:一个字典,包含用于生成新数据集的参数。
data:一个 Pandas DataFrame,用于生成新数据集。
stop_randomization:一个布尔值,指示是否停止随机化编码器和解码器长度。默认为 None。
predict:一个布尔值,指示是否在时间索引的最后条目上预测解码器长度(即每个组仅预测一次)。默认为 False。
**update_kwargs:可变数量的关键字参数,用于覆盖 parameters 中的值。
这个方法的主要步骤是:

复制传入的 parameters 参数。
如果 predict 为 True,则更新 stop_randomization 的值为 True。如果 stop_randomization 不为 None 且为 False,则发出一个警告,并将 stop_randomization 设置为 True。设置参数 “min_prediction_length” 为 “max_prediction_length”,参数 “predict_mode” 为 True。
如果 stop_randomization 为 None,则将其设置为 False。
如果 stop_randomization 为 True,则将参数 “randomize_length” 设置为 None。
使用更新后的参数和传入的数据创建一个新的 TimeSeriesDataSet 对象,并返回。

    def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFrame:
        """
        Create index of samples.

        Args:
            data (pd.DataFrame): preprocessed data
            predict_mode (bool): if to create one same per group with prediction length equals ``max_decoder_length``

        Returns:
            pd.DataFrame: index dataframe for timesteps and index dataframe for groups.
                It contains a list of all possible subsequences.
        """
        g = data.groupby(self._group_ids, observed=True)

        df_index_first = g["__time_idx__"].transform("nth", 0).to_frame("time_first")
        df_index_last = g["__time_idx__"].transform("nth", -1).to_frame("time_last")
        df_index_diff_to_next = -g["__time_idx__"].diff(-1).fillna(-1).astype(int).to_frame("time_diff_to_next")
        df_index = pd.concat([df_index_first, df_index_last, df_index_diff_to_next], axis=1)
        df_index["index_start"] = np.arange(len(df_index))
        df_index["time"] = data["__time_idx__"]
        df_index["count"] = (df_index["time_last"] - df_index["time_first"]).astype(int) + 1
        sequence_ids = g.ngroup()
        df_index["sequence_id"] = sequence_ids

        min_sequence_length = self.min_prediction_length + self.min_encoder_length
        max_sequence_length = self.max_prediction_length + self.max_encoder_length

        # calculate maximum index to include from current index_start
        max_time = (df_index["time"] + max_sequence_length - 1).clip(upper=df_index["count"] + df_index.time_first - 1)

        # if there are missing timesteps, we cannot say directly what is the last timestep to include
        # therefore we iterate until it is found
        if (df_index["time_diff_to_next"] != 1).any():
            assert (
                self.allow_missing_timesteps
            ), "Time difference between steps has been idenfied as larger than 1 - set allow_missing_timesteps=True"

        df_index["index_end"], missing_sequences = _find_end_indices(
            diffs=df_index.time_diff_to_next.to_numpy(),
            max_lengths=(max_time - df_index.time).to_numpy() + 1,
            min_length=min_sequence_length,
        )
        # add duplicates but mostly with shorter sequence length for start of timeseries
        # while the previous steps have ensured that we start a sequence on every time step, the missing_sequences
        # ensure that there is a sequence that finishes on every timestep
        if len(missing_sequences) > 0:
            shortened_sequences = df_index.iloc[missing_sequences[:, 0]].assign(index_end=missing_sequences[:, 1])

            # concatenate shortened sequences
            df_index = pd.concat([df_index, shortened_sequences], axis=0, ignore_index=True)

        # filter out where encode and decode length are not satisfied
        df_index["sequence_length"] = df_index["time"].iloc[df_index["index_end"]].to_numpy() - df_index["time"] + 1

        # filter too short sequences
        df_index = df_index[
            # sequence must be at least of minimal prediction length
            lambda x: (x.sequence_length >= min_sequence_length)
            &
            # prediction must be for after minimal prediction index + length of prediction
            (x["sequence_length"] + x["time"] >= self.min_prediction_idx + self.min_prediction_length)
        ]

        if predict_mode:  # keep longest element per series (i.e. the first element that spans to the end of the series)
            # filter all elements that are longer than the allowed maximum sequence length
            df_index = df_index[
                lambda x: (x["time_last"] - x["time"] + 1 <= max_sequence_length)
                & (x["sequence_length"] >= min_sequence_length)
            ]
            # choose longest sequence
            df_index = df_index.loc[df_index.groupby("sequence_id").sequence_length.idxmax()]

        # check that all groups/series have at least one entry in the index
        if not sequence_ids.isin(df_index.sequence_id).all():
            missing_groups = data.loc[~sequence_ids.isin(df_index.sequence_id), self._group_ids].drop_duplicates()
            # decode values
            for name, id in self._group_ids_mapping.items():
                missing_groups[id] = self.transform_values(name, missing_groups[id], inverse=True, group_id=True)
            warnings.warn(
                "Min encoder length and/or min_prediction_idx and/or min prediction length and/or lags are "
                "too large for "
                f"{len(missing_groups)} series/groups which therefore are not present in the dataset index. "
                "This means no predictions can be made for those series. "
                f"First 10 removed groups: {list(missing_groups.iloc[:10].to_dict(orient='index').values())}",
                UserWarning,
            )
        assert (
            len(df_index) > 0
        ), "filters should not remove entries all entries - check encoder/decoder lengths and lags"

        return df_index

这是一个预处理数据的函数,目的是为了构建一个索引来帮助模型生成序列预测。函数输入一个pandas DataFrame类型的数据和一个布尔类型的predict_mode变量,返回一个包含时间步和组索引数据帧的数据帧。其中,时间步是所有可能子序列的列表,而组则是不同的时间序列。具体而言,这个函数实现了以下步骤:

1.根据数据中的组标识符对数据进行分组。

2.使用分组数据的第一个和最后一个时间步,以及相邻时间步之间的时间差,构建一个包含索引开始、索引结束、时间和计数列的数据帧。

3.计算序列的最小长度和最大长度,并计算从当前索引开始应包含的最大索引。

4.如果存在缺失的时间步,需要进行迭代,直到找到缺失的时间步。

5.将缺失的序列添加到数据帧中,并缩短其长度。

6.过滤长度不足的序列。

7.如果predict_mode为True,则仅保留每个序列中最长的元素。

8.确保索引中包含所有的组。

9.返回索引数据帧。

    def filter(self, filter_func: Callable, copy: bool = True) -> "TimeSeriesDataSet":
        """
        Filter subsequences in dataset.

        Uses interpretable version of index :py:meth:`~decoded_index`
        to filter subsequences in dataset.

        Args:
            filter_func (Callable): function to filter. Should take :py:meth:`~decoded_index`
                dataframe as only argument which contains group ids and time index columns.
            copy (bool): if to return copy of dataset or filter inplace.

        Returns:
            TimeSeriesDataSet: filtered dataset
        """
        # calculate filter
        filtered_index = self.index[np.asarray(filter_func(self.decoded_index))]
        # raise error if filter removes all entries
        if len(filtered_index) == 0:
            raise ValueError("After applying filter no sub-sequences left in dataset")
        if copy:
            dataset = _copy(self)
            dataset.index = filtered_index
            return dataset
        else:
            self.index = filtered_index
            return self

    @property
    def decoded_index(self) -> pd.DataFrame:
        """
        Get interpretable version of index.

        DataFrame contains
        - group_id columns in original encoding
        - time_idx_first column: first time index of subsequence
        - time_idx_last columns: last time index of subsequence
        - time_idx_first_prediction columns: first time index which is in decoder

        Returns:
            pd.DataFrame: index that can be understood in terms of original data
        """
        # get dataframe to filter
        index_start = self.index["index_start"].to_numpy()
        index_last = self.index["index_end"].to_numpy()
        index = (
            # get group ids in order of index
            pd.DataFrame(self.data["groups"][index_start].numpy(), columns=self.group_ids)
            # to original values
            .apply(lambda x: self.transform_values(name=x.name, values=x, group_id=True, inverse=True))
            # add time index
            .assign(
                time_idx_first=self.data["time"][index_start].numpy(),
                time_idx_last=self.data["time"][index_last].numpy(),
                # prediction index is last time index - decoder length + 1
                time_idx_first_prediction=lambda x: x.time_idx_last
                - self.calculate_decoder_length(
                    time_last=x.time_idx_last, sequence_length=x.time_idx_last - x.time_idx_first + 1
                )
                + 1,
            )
        )
        return index

这段代码是一个 Python 类的方法 filter,用于过滤时间序列数据集中的子序列。具体来说,这个方法接受两个参数:filter_func 和 copy。filter_func 是一个可调用对象,接受一个 DataFrame 作为输入,其中包含分组 ID 和时间索引列。copy 参数是一个布尔值,用于指定是否返回数据集的副本或者就地修改数据集。

方法的返回值是一个经过过滤后的时间序列数据集对象。如果 copy 为 True,则返回一个数据集的副本,否则就在原始数据集上进行修改。如果过滤后没有子序列留下,则会抛出 ValueError。

这个类还定义了一个 decoded_index 的属性方法,用于获取解释性的索引,其中包含分组 ID、子序列的起始和结束时间索引以及解码器的第一个时间索引。具体而言,这个方法将数据集的索引转换为可以理解为原始数据的形式。这个方法的返回值是一个 DataFrame 对象。

    def plot_randomization(
        self, betas: Tuple[float, float] = None, length: int = None, min_length: int = None
    ) -> Tuple[plt.Figure, torch.Tensor]:
        """
        Plot expected randomized length distribution.

        Args:
            betas (Tuple[float, float], optional): Tuple of betas, e.g. ``(0.2, 0.05)`` to use for randomization.
                Defaults to ``randomize_length`` of dataset.
            length (int, optional): . Defaults to ``max_encoder_length``.
            min_length (int, optional): [description]. Defaults to ``min_encoder_length``.

        Returns:
            Tuple[plt.Figure, torch.Tensor]: tuple of figure and histogram based on 1000 samples
        """
        if betas is None:
            betas = self.randomize_length
        if length is None:
            length = self.max_encoder_length
        if min_length is None:
            min_length = self.min_encoder_length
        probabilities = Beta(betas[0], betas[1]).sample((1000,))

        lengths = ((length - min_length) * probabilities).round() + min_length

        fig, ax = plt.subplots()
        ax.hist(lengths)
        return fig, lengths

    def __len__(self) -> int:
        """
        Length of dataset.

        Returns:
            int: length
        """
        return self.index.shape[0]

    def set_overwrite_values(
        self, values: Union[float, torch.Tensor], variable: str, target: Union[str, slice] = "decoder"
    ) -> None:
        """
        Convenience method to quickly overwrite values in decoder or encoder (or both) for a specific variable.

        Args:
            values (Union[float, torch.Tensor]): values to use for overwrite.
            variable (str): variable whose values should be overwritten.
            target (Union[str, slice], optional): positions to overwrite. One of "decoder", "encoder" or "all" or
                a slice object which is directly used to overwrite indices, e.g. ``slice(-5, None)`` will overwrite
                the last 5 values. Defaults to "decoder".
        """
        values = torch.tensor(self.transform_values(variable, np.asarray(values).reshape(-1), inverse=False)).squeeze()
        assert target in [
            "all",
            "decoder",
            "encoder",
        ], f"target has be one of 'all', 'decoder' or 'encoder' but target={target} instead"

        if variable in self.static_categoricals or variable in self.static_categoricals:
            target = "all"

        if variable in self.target_names:
            raise NotImplementedError("Target variable is not supported")
        if self.weight is not None and self.weight == variable:
            raise NotImplementedError("Weight variable is not supported")
        if isinstance(self.scalers.get(variable, self.categorical_encoders.get(variable)), TorchNormalizer):
            raise NotImplementedError("TorchNormalizer (e.g. GroupNormalizer) is not supported")

        if self._overwrite_values is None:
            self._overwrite_values = {}
        self._overwrite_values.update(dict(values=values, variable=variable, target=target))

    def reset_overwrite_values(self) -> None:
        """
        Reset values used to override sample features.
        """
        self._overwrite_values = None

    def calculate_decoder_length(
        self,
        time_last: Union[int, pd.Series, np.ndarray],
        sequence_length: Union[int, pd.Series, np.ndarray],
    ) -> Union[int, pd.Series, np.ndarray]:
        """
        Calculate length of decoder.

        Args:
            time_last (Union[int, pd.Series, np.ndarray]): last time index of the sequence
            sequence_length (Union[int, pd.Series, np.ndarray]): total length of the sequence

        Returns:
            Union[int, pd.Series, np.ndarray]: decoder length(s)
        """
        if isinstance(time_last, int):
            decoder_length = min(
                time_last - (self.min_prediction_idx - 1),  # not going beyond min prediction idx
                self.max_prediction_length,  # maximum prediction length
                sequence_length - self.min_encoder_length,  # sequence length - min decoder length
            )
        else:
            decoder_length = np.min(
                [
                    time_last - (self.min_prediction_idx - 1),
                    sequence_length - self.min_encoder_length,
                ],
                axis=0,
            ).clip(max=self.max_prediction_length)
        return decoder_length

这些代码是一个类的方法,该类是用于时间序列预测的数据集。下面是每个方法的详细功能:

plot_randomization: 这个方法用于绘制预期随机长度分布的直方图。可以通过传递 betas 参数来指定随机化长度的范围。默认情况下,将使用数据集的 randomize_length 值作为 betas 参数。length 和 min_length 参数分别指定编码器序列的最大和最小长度。方法返回一个元组,包括绘制的图形和基于1000个样本的直方图。

len: 这个方法返回数据集的长度。

set_overwrite_values: 这个方法用于快速覆盖编码器或解码器(或两者)中特定变量的值。values 参数指定要使用的值,variable 参数指定要覆盖值的变量,target 参数指定要覆盖哪些位置。可选的 target 参数包括 “decoder”,“encoder” 或 “all”,或者直接使用 slice 对象来覆盖索引。方法不返回任何内容,但是会在数据集对象中存储覆盖的值。

reset_overwrite_values: 这个方法用于重置用于覆盖样本特征的值。

calculate_decoder_length: 这个方法用于计算解码器的长度。它基于最后一个时间索引、序列长度以及数据集的最小和最大解码器长度来计算解码器的长度。返回的是一个整数或数组,表示解码器的长度。

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Get sample for model

        Args:
            idx (int): index of prediction (between ``0`` and ``len(dataset) - 1``)

        Returns:
            Tuple[Dict[str, torch.Tensor], torch.Tensor]: x and y for model
        """
        index = self.index.iloc[idx]
        # get index data
        data_cont = self.data["reals"][index.index_start : index.index_end + 1].clone()
        data_cat = self.data["categoricals"][index.index_start : index.index_end + 1].clone()
        time = self.data["time"][index.index_start : index.index_end + 1].clone()
        target = [d[index.index_start : index.index_end + 1].clone() for d in self.data["target"]]
        groups = self.data["groups"][index.index_start].clone()
        if self.data["weight"] is None:
            weight = None
        else:
            weight = self.data["weight"][index.index_start : index.index_end + 1].clone()
        # get target scale in the form of a list
        target_scale = self.target_normalizer.get_parameters(groups, self.group_ids)
        if not isinstance(self.target_normalizer, MultiNormalizer):
            target_scale = [target_scale]

        # fill in missing values (if not all time indices are specified
        sequence_length = len(time)
        if sequence_length < index.sequence_length:
            assert self.allow_missing_timesteps, "allow_missing_timesteps should be True if sequences have gaps"
            repetitions = torch.cat([time[1:] - time[:-1], torch.ones(1, dtype=time.dtype)])
            indices = torch.repeat_interleave(torch.arange(len(time)), repetitions)
            repetition_indices = torch.cat([torch.tensor([False], dtype=torch.bool), indices[1:] == indices[:-1]])

            # select data
            data_cat = data_cat[indices]
            data_cont = data_cont[indices]
            target = [d[indices] for d in target]
            if weight is not None:
                weight = weight[indices]

            # reset index
            if self.time_idx in self.reals:
                time_idx = self.reals.index(self.time_idx)
                data_cont[:, time_idx] = torch.linspace(
                    data_cont[0, time_idx], data_cont[-1, time_idx], len(target[0]), dtype=data_cont.dtype
                )

            # make replacements to fill in categories
            for name, value in self.encoded_constant_fill_strategy.items():
                if name in self.reals:
                    data_cont[repetition_indices, self.reals.index(name)] = value
                elif name in [f"__target__{target_name}" for target_name in self.target_names]:
                    target_pos = self.target_names.index(name[len("__target__") :])
                    target[target_pos][repetition_indices] = value
                elif name in self.flat_categoricals:
                    data_cat[repetition_indices, self.flat_categoricals.index(name)] = value
                elif name in self.target_names:  # target is just not an input value
                    pass
                else:
                    raise KeyError(f"Variable {name} is not known and thus cannot be filled in")

            sequence_length = len(target[0])

        # determine data window
        assert (
            sequence_length >= self.min_prediction_length
        ), "Sequence length should be at least minimum prediction length"
        # determine prediction/decode length and encode length
        decoder_length = self.calculate_decoder_length(time[-1], sequence_length)
        encoder_length = sequence_length - decoder_length
        assert (
            decoder_length >= self.min_prediction_length
        ), "Decoder length should be at least minimum prediction length"
        assert encoder_length >= self.min_encoder_length, "Encoder length should be at least minimum encoder length"

        if self.randomize_length is not None:  # randomization improves generalization
            # modify encode and decode lengths
            modifiable_encoder_length = encoder_length - self.min_encoder_length
            encoder_length_probability = Beta(self.randomize_length[0], self.randomize_length[1]).sample()

            # subsample a new/smaller encode length
            new_encoder_length = self.min_encoder_length + int(
                (modifiable_encoder_length * encoder_length_probability).round()
            )

            # extend decode length if possible
            new_decoder_length = min(decoder_length + (encoder_length - new_encoder_length), self.max_prediction_length)

            # select subset of sequence of new sequence
            if new_encoder_length + new_decoder_length < len(target[0]):
                data_cat = data_cat[encoder_length - new_encoder_length : encoder_length + new_decoder_length]
                data_cont = data_cont[encoder_length - new_encoder_length : encoder_length + new_decoder_length]
                target = [t[encoder_length - new_encoder_length : encoder_length + new_decoder_length] for t in target]
                encoder_length = new_encoder_length
                decoder_length = new_decoder_length

            # switch some variables to nan if encode length is 0
            if encoder_length == 0 and len(self.dropout_categoricals) > 0:
                data_cat[
                    :, [self.flat_categoricals.index(c) for c in self.dropout_categoricals]
                ] = 0  # zero is encoded nan

        assert decoder_length > 0, "Decoder length should be greater than 0"
        assert encoder_length >= 0, "Encoder length should be at least 0"

        if self.add_relative_time_idx:
            data_cont[:, self.reals.index("relative_time_idx")] = (
                torch.arange(-encoder_length, decoder_length, dtype=data_cont.dtype) / self.max_encoder_length
            )

        if self.add_encoder_length:
            data_cont[:, self.reals.index("encoder_length")] = (
                (encoder_length - 0.5 * self.max_encoder_length) / self.max_encoder_length * 2.0
            )

        # rescale target
        for idx, target_normalizer in enumerate(self.target_normalizers):
            if isinstance(target_normalizer, EncoderNormalizer):
                target_name = self.target_names[idx]
                # fit and transform
                target_normalizer.fit(target[idx][:encoder_length])
                # get new scale
                single_target_scale = target_normalizer.get_parameters()
                # modify input data
                if target_name in self.reals:
                    data_cont[:, self.reals.index(target_name)] = target_normalizer.transform(target[idx])
                if self.add_target_scales:
                    data_cont[:, self.reals.index(f"{target_name}_center")] = self.transform_values(
                        f"{target_name}_center", single_target_scale[0]
                    )[0]
                    data_cont[:, self.reals.index(f"{target_name}_scale")] = self.transform_values(
                        f"{target_name}_scale", single_target_scale[1]
                    )[0]
                # scale needs to be numpy to be consistent with GroupNormalizer
                target_scale[idx] = single_target_scale.numpy()

        # rescale covariates
        for name in self.reals:
            if name not in self.target_names and name not in self.lagged_variables:
                normalizer = self.get_transformer(name)
                if isinstance(normalizer, EncoderNormalizer):
                    # fit and transform
                    pos = self.reals.index(name)
                    normalizer.fit(data_cont[:encoder_length, pos])
                    # transform
                    data_cont[:, pos] = normalizer.transform(data_cont[:, pos])

        # also normalize lagged variables
        for name in self.reals:
            if name in self.lagged_variables:
                normalizer = self.get_transformer(name)
                if isinstance(normalizer, EncoderNormalizer):
                    pos = self.reals.index(name)
                    data_cont[:, pos] = normalizer.transform(data_cont[:, pos])

        # overwrite values
        if self._overwrite_values is not None:
            if isinstance(self._overwrite_values["target"], slice):
                positions = self._overwrite_values["target"]
            elif self._overwrite_values["target"] == "all":
                positions = slice(None)
            elif self._overwrite_values["target"] == "encoder":
                positions = slice(None, encoder_length)
            else:  # decoder
                positions = slice(encoder_length, None)

            if self._overwrite_values["variable"] in self.reals:
                idx = self.reals.index(self._overwrite_values["variable"])
                data_cont[positions, idx] = self._overwrite_values["values"]
            else:
                assert (
                    self._overwrite_values["variable"] in self.flat_categoricals
                ), "overwrite values variable has to be either in real or categorical variables"
                idx = self.flat_categoricals.index(self._overwrite_values["variable"])
                data_cat[positions, idx] = self._overwrite_values["values"]

        # weight is only required for decoder
        if weight is not None:
            weight = weight[encoder_length:]

        # if user defined target as list, output should be list, otherwise tensor
        if self.multi_target:
            encoder_target = [t[:encoder_length] for t in target]
            target = [t[encoder_length:] for t in target]
        else:
            encoder_target = target[0][:encoder_length]
            target = target[0][encoder_length:]
            target_scale = target_scale[0]

        return (
            dict(
                x_cat=data_cat,
                x_cont=data_cont,
                encoder_length=encoder_length,
                decoder_length=decoder_length,
                encoder_target=encoder_target,
                encoder_time_idx_start=time[0],
                groups=groups,
                target_scale=target_scale,
            ),
            (target, weight),
        )

这是一个Python的类方法,名称为__getitem__,用于索引访问类的对象实例。

这个方法接受一个整数类型的参数 idx,表示数据集中某个样本的索引。然后,从数据集中获取相应的样本数据,对这些数据进行处理和准备,以供模型使用。方法的返回值为一个元组,包含两个元素。第一个元素为字典类型的数据 x,其中包含模型需要的所有输入变量的值。第二个元素为张量类型的数据 y,表示该样本对应的目标变量的值。

具体来说,这个方法首先根据输入参数 idx 获取数据集中对应样本的索引信息 index。然后,从数据集中获取该样本对应的连续型变量数据 data_cont,分类型变量数据 data_cat,时间变量数据 time,目标变量数据 target,分组变量数据 groups 和权重变量数据 weight。如果权重变量不存在,则将其置为None。此外,还会获取目标变量的标准化参数 target_scale,并将其转换为列表格式。如果目标变量标准化器为MultiNormalizer,则target_scale即为列表本身,否则需要将其转换为列表。

接下来,对数据进行缺失值处理。如果时间序列存在缺失值,则根据allow_missing_timesteps参数设置,选择是否允许存在缺失值。如果允许,则需要对缺失值进行插值。具体来说,将缺失值所在的时间点拆分成多个子区间,然后对每个子区间进行线性插值,以填补缺失值。此外,还需要使用指定的填充策略对分类型变量的缺失值进行填充,以保证数据的完整性。

最后,根据处理后的数据,计算出编码器和解码器的长度,并根据参数randomize_length的设置,随机调整编码器和解码器的长度。然后,从处理后的数据中选择编码器和解码器需要的子序列,并将其作为模型的输入输出数据返回。

 @staticmethod
    def _collate_fn(
        batches: List[Tuple[Dict[str, torch.Tensor], torch.Tensor]]
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Collate function to combine items into mini-batch for dataloader.

        Args:
            batches (List[Tuple[Dict[str, torch.Tensor], torch.Tensor]]): List of samples generated with
                :py:meth:`~__getitem__`.

        Returns:
            Tuple[Dict[str, torch.Tensor], Tuple[Union[torch.Tensor, List[torch.Tensor]], torch.Tensor]: minibatch
        """
        # collate function for dataloader
        # lengths
        encoder_lengths = torch.tensor([batch[0]["encoder_length"] for batch in batches], dtype=torch.long)
        decoder_lengths = torch.tensor([batch[0]["decoder_length"] for batch in batches], dtype=torch.long)

        # ids
        decoder_time_idx_start = (
            torch.tensor([batch[0]["encoder_time_idx_start"] for batch in batches], dtype=torch.long) + encoder_lengths
        )
        decoder_time_idx = decoder_time_idx_start.unsqueeze(1) + torch.arange(decoder_lengths.max()).unsqueeze(0)
        groups = torch.stack([batch[0]["groups"] for batch in batches])

        # features
        encoder_cont = rnn.pad_sequence(
            [batch[0]["x_cont"][:length] for length, batch in zip(encoder_lengths, batches)], batch_first=True
        )
        encoder_cat = rnn.pad_sequence(
            [batch[0]["x_cat"][:length] for length, batch in zip(encoder_lengths, batches)], batch_first=True
        )

        decoder_cont = rnn.pad_sequence(
            [batch[0]["x_cont"][length:] for length, batch in zip(encoder_lengths, batches)], batch_first=True
        )
        decoder_cat = rnn.pad_sequence(
            [batch[0]["x_cat"][length:] for length, batch in zip(encoder_lengths, batches)], batch_first=True
        )

        # target scale
        if isinstance(batches[0][0]["target_scale"], torch.Tensor):  # stack tensor
            target_scale = torch.stack([batch[0]["target_scale"] for batch in batches])
        elif isinstance(batches[0][0]["target_scale"], (list, tuple)):
            target_scale = []
            for idx in range(len(batches[0][0]["target_scale"])):
                if isinstance(batches[0][0]["target_scale"][idx], torch.Tensor):  # stack tensor
                    scale = torch.stack([batch[0]["target_scale"][idx] for batch in batches])
                else:
                    scale = torch.from_numpy(
                        np.array([batch[0]["target_scale"][idx] for batch in batches], dtype=np.float32),
                    )
                target_scale.append(scale)
        else:  # convert to tensor
            target_scale = torch.from_numpy(
                np.array([batch[0]["target_scale"] for batch in batches], dtype=np.float32),
            )

        # target and weight
        if isinstance(batches[0][1][0], (tuple, list)):
            target = [
                rnn.pad_sequence([batch[1][0][idx] for batch in batches], batch_first=True)
                for idx in range(len(batches[0][1][0]))
            ]
            encoder_target = [
                rnn.pad_sequence([batch[0]["encoder_target"][idx] for batch in batches], batch_first=True)
                for idx in range(len(batches[0][1][0]))
            ]
        else:
            target = rnn.pad_sequence([batch[1][0] for batch in batches], batch_first=True)
            encoder_target = rnn.pad_sequence([batch[0]["encoder_target"] for batch in batches], batch_first=True)

        if batches[0][1][1] is not None:
            weight = rnn.pad_sequence([batch[1][1] for batch in batches], batch_first=True)
        else:
            weight = None

        return (
            dict(
                encoder_cat=encoder_cat,
                encoder_cont=encoder_cont,
                encoder_target=encoder_target,
                encoder_lengths=encoder_lengths,
                decoder_cat=decoder_cat,
                decoder_cont=decoder_cont,
                decoder_target=target,
                decoder_lengths=decoder_lengths,
                decoder_time_idx=decoder_time_idx,
                groups=groups,
                target_scale=target_scale,
            ),
            (target, weight),
        )

这个函数是用来将一个列表中的多个样本组合成一个mini-batch的函数。下面是函数的具体步骤和输入输出:

输入:

batches:包含多个样本的列表,每个样本由一个输入字典和一个输出标签组成。
输出:

一个元组,包含两个元素:
一个字典,包含以下键值对:
encoder_cat:编码器输入的类别特征,形状为(batch_size, max_encoder_length, num_cat_features)。
encoder_cont:编码器输入的连续特征,形状为(batch_size, max_encoder_length, num_cont_features)。
encoder_target:编码器的目标,形状为(batch_size, max_encoder_length)。
encoder_lengths:编码器每个样本的有效长度,形状为(batch_size,)。
decoder_cat:解码器输入的类别特征,形状为(batch_size, max_decoder_length, num_cat_features)。
decoder_cont:解码器输入的连续特征,形状为(batch_size, max_decoder_length, num_cont_features)。
decoder_target:解码器的目标,形状为(batch_size, max_decoder_length)或(batch_size, num_decoder_outputs, max_decoder_length)。
decoder_lengths:解码器每个样本的有效长度,形状为(batch_size,)。
decoder_time_idx:解码器每个时间步的索引,形状为(batch_size, max_decoder_length)。
groups:样本所属的分组,形状为(batch_size,)。
target_scale:目标的缩放因子,形状为(batch_size,)或(batch_size, num_decoder_outputs)。
一个元组,包含两个元素:
decoder_target或decoder_target的列表,与上面字典中的decoder_target相对应。
权重,形状与decoder_target相同,或者为None。
函数具体步骤:

从batches列表中提取每个样本的编码器长度(encoder_lengths)和解码器长度(decoder_lengths)。
根据encoder_lengths和解码器的起始时间索引(encoder_time_idx_start),计算decoder_time_idx。其中decoder_time_idx_start是从encoder_time_idx_start + encoder_lengths计算得出的。
提取每个样本的分组(groups)。
对编码器输入的类别特征(encoder_cat)和连续特征(encoder_cont)进行padding。
对解码器输入的类别特征(decoder_cat)和连续特征(decoder_cont)进行padding。
提取每个样本的目标(encoder_target)和权重(weight),并对它们进行padding。
对目标的缩放因子(target_scale)进行处理,如果是张量,则进行stack操作,否则将其转换为张量。
最后,函数返回上述字典和元组。

    def to_dataloader(
        self, train: bool = True, batch_size: int = 64, batch_sampler: Union[Sampler, str] = None, **kwargs
    ) -> DataLoader:
        """
        Get dataloader from dataset.

        The

        Args:
            train (bool, optional): if dataloader is used for training or prediction
                Will shuffle and drop last batch if True. Defaults to True.
            batch_size (int): batch size for training model. Defaults to 64.
            batch_sampler (Union[Sampler, str]): batch sampler or string. One of

                * "synchronized": ensure that samples in decoder are aligned in time. Does not support missing
                  values in dataset. This makes only sense if the underlying algorithm makes use of values aligned
                  in time.
                * PyTorch Sampler instance: any PyTorch sampler, e.g. the WeightedRandomSampler()
                * None: samples are taken randomly from times series.

            **kwargs: additional arguments to ``DataLoader()``

        Returns:
            DataLoader: dataloader that returns Tuple.
                First entry is ``x``, a dictionary of tensors with the entries (and shapes in brackets)

                * encoder_cat (batch_size x n_encoder_time_steps x n_features): long tensor of encoded
                  categoricals for encoder
                * encoder_cont (batch_size x n_encoder_time_steps x n_features): float tensor of scaled continuous
                  variables for encoder
                * encoder_target (batch_size x n_encoder_time_steps or list thereof with each entry for a different
                  target):
                  float tensor with unscaled continous target or encoded categorical target,
                  list of tensors for multiple targets
                * encoder_lengths (batch_size): long tensor with lengths of the encoder time series. No entry will
                  be greater than n_encoder_time_steps
                * decoder_cat (batch_size x n_decoder_time_steps x n_features): long tensor of encoded
                  categoricals for decoder
                * decoder_cont (batch_size x n_decoder_time_steps x n_features): float tensor of scaled continuous
                  variables for decoder
                * decoder_target (batch_size x n_decoder_time_steps or list thereof with each entry for a different
                  target):
                  float tensor with unscaled continous target or encoded categorical target for decoder
                  - this corresponds to first entry of ``y``, list of tensors for multiple targets
                * decoder_lengths (batch_size): long tensor with lengths of the decoder time series. No entry will
                  be greater than n_decoder_time_steps
                * group_ids (batch_size x number_of_ids): encoded group ids that identify a time series in the dataset
                * target_scale (batch_size x scale_size or list thereof with each entry for a different target):
                  parameters used to normalize the target.
                  Typically these are mean and standard deviation. Is list of tensors for multiple targets.


                Second entry is ``y``, a tuple of the form (``target``, `weight`)

                * target (batch_size x n_decoder_time_steps or list thereof with each entry for a different target):
                  unscaled (continuous) or encoded (categories) targets, list of tensors for multiple targets
                * weight (None or batch_size x n_decoder_time_steps): weight

        Example:

            Weight by samples for training:

            .. code-block:: python

                from torch.utils.data import WeightedRandomSampler

                # length of probabilties for sampler have to be equal to the length of the index
                probabilities = np.sqrt(1 + data.loc[dataset.index, "target"])
                sampler = WeightedRandomSampler(probabilities, len(probabilities))
                dataset.to_dataloader(train=True, sampler=sampler, shuffle=False)
        """
        default_kwargs = dict(
            shuffle=train,
            drop_last=train and len(self) > batch_size,
            collate_fn=self._collate_fn,
            batch_size=batch_size,
            batch_sampler=batch_sampler,
        )
        default_kwargs.update(kwargs)
        kwargs = default_kwargs
        if kwargs["batch_sampler"] is not None:
            sampler = kwargs["batch_sampler"]
            if isinstance(sampler, str):
                if sampler == "synchronized":
                    kwargs["batch_sampler"] = TimeSynchronizedBatchSampler(
                        self, batch_size=kwargs["batch_size"], shuffle=kwargs["shuffle"], drop_last=kwargs["drop_last"]
                    )
                else:
                    raise ValueError(f"batch_sampler {sampler} unknown - see docstring for valid batch_sampler")
            del kwargs["batch_size"]
            del kwargs["shuffle"]
            del kwargs["drop_last"]

        return DataLoader(
            self,
            **kwargs,
        )

    def x_to_index(self, x: Dict[str, torch.Tensor]) -> pd.DataFrame:
        """
        Decode dataframe index from x.

        Returns:
            dataframe with time index column for first prediction and group ids
        """
        index_data = {self.time_idx: x["decoder_time_idx"][:, 0].cpu()}
        for id in self.group_ids:
            index_data[id] = x["groups"][:, self.group_ids.index(id)].cpu()
            # decode if possible
            index_data[id] = self.transform_values(id, index_data[id], inverse=True, group_id=True)
        index = pd.DataFrame(index_data)
        return index

    def __repr__(self) -> str:
        return repr_class(self, attributes=self.get_parameters(), extra_attributes=dict(length=len(self)))

这是一个Python函数,名为 to_dataloader。下面是逐行解释:

def to_dataloader(
        self, train: bool = True, batch_size: int = 64, batch_sampler: Union[Sampler, str] = None, **kwargs
    ) -> DataLoader:

这个函数将返回一个PyTorch的DataLoader对象,用于对数据进行迭代。它有三个参数,一个可选的布尔型 train(用于指定是否用于训练),一个整型 batch_size(指定批处理大小,默认为64)和一个可选的 batch_sampler,类型可以是 Sampler 或 str(指定如何采样数据)。另外,这个函数可以接受任意数量的关键字参数,它们将被传递给 DataLoader()。

default_kwargs = dict(
            shuffle=train,
            drop_last=train and len(self) > batch_size,
            collate_fn=self._collate_fn,
            batch_size=batch_size,
            batch_sampler=batch_sampler,
        )
        default_kwargs.update(kwargs)
        kwargs = default_kwargs

这个函数首先定义了一个默认参数的字典,然后将其更新为传入的任何关键字参数。最后,将 kwargs 变量设置为更新后的字典。

if kwargs["batch_sampler"] is not None:
            sampler = kwargs["batch_sampler"]
            if isinstance(sampler, str):
                if sampler == "synchronized":
                    kwargs["batch_sampler"] = TimeSynchronizedBatchSampler(
                        self, batch_size=kwargs["batch_size"], shuffle=kwargs["shuffle"], drop_last=kwargs["drop_last"]
                    )
                else:
                    raise ValueError(f"batch_sampler {sampler} unknown - see docstring for valid batch_sampler")
            del kwargs["batch_size"]
            del kwargs["shuffle"]
            del kwargs["drop_last"]

如果 batch_sampler 不是 None,那么检查它是否是字符串类型。如果是,则根据字符串的值创建一个 TimeSynchronizedBatchSampler 对象。否则,抛出一个 ValueError 异常。然后删除不再需要的三个键-值对:batch_size、shuffle 和 drop_last。

return DataLoader(
            self,
            **kwargs,
        )

最后,这个函数返回一个 DataLoader 对象,它以 self(该函数所在的对象)作为第一个参数,以 kwargs 中的所有关键字参数作为其余参数。

除了 to_dataloader 函数,还有两个辅助函数。第一个是 x_to_index,它将 x 中的时间索引和组 ID 解码为 Pandas DataFrame。第二个是 repr,它返回该对象的字符串表示,包括对象的参数和长度。

  • 7
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值