TIME - MoE 模型代码 3.3——Time-MoE-main/time_moe/datasets/time_moe_window_dataset.py

源码:https://github.com/Time-MoE/Time-MoE

这段代码实现了两个用于时间序列数据处理的窗口化数据集类,主要用于将长序列切割成固定长度的子序列,为模型训练提供合适的输入格式。


1. 核心类:TimeMoEWindowDataset

1.1 功能概述

将长时间序列转换为固定长度的非重叠滑动窗口,每个窗口包含:

  • 输入序列input_ids):长度为context_length
  • 标签序列labels):长度为context_length + prediction_length,与输入序列错位 1 个时间步
  • 损失掩码loss_masks):标记哪些位置需要计算损失

1.2 关键参数

  • context_length:输入序列长度(历史信息)
  • prediction_length:预测序列长度(未来信息,默认为 0)
  • stride:窗口滑动步长(默认为窗口大小,即非重叠)

1.3 初始化逻辑

def __init__(self, dataset, context_length, prediction_length=0, stride=None):
    self.dataset = dataset
    self.context_length = context_length
    self.prediction_length = prediction_length
    self.window_size = context_length + prediction_length
    self.stride = stride or self.window_size  # 默认非重叠
    
    # 构建子序列索引列表
    self.sub_seq_indexes = []
    for seq_idx in range(len(dataset)):
        n_points = dataset.get_sequence_length_by_idx(seq_idx)
        if n_points < 2:
            continue
        # 添加初始窗口
        self.sub_seq_indexes.append((seq_idx, 0))
        # 添加后续窗口(按stride滑动)
        for offset in range(self.stride, n_points - self.window_size - 1 + 1, self.stride):
            self.sub_seq_indexes.append((seq_idx, offset))

1.4数据获取逻辑

def __getitem__(self, seq_idx):
    seq_i, offset_i = self.sub_seq_indexes[seq_idx]
    # 提取窗口数据(包含额外1个点用于错位)
    seq = self.dataset[seq_i][offset_i: offset_i + self.window_size + 1]
    seq = np.array(seq, dtype=np.float32)
    
    # 创建损失掩码(标记有效位置)
    loss_mask = np.ones(len(seq) - 1, dtype=np.int32)
    # 处理序列长度不足的情况(填充0)
    n_pad = self.window_size + 1 - len(seq)
    if n_pad > 0:
        seq = np.pad(seq, (0, n_pad), 'constant', constant_values=0)
        loss_mask = np.pad(loss_mask, (0, n_pad), 'constant', constant_values=0)
    
    return {
        'input_ids': seq[:-1],      # 输入序列
        'labels': seq[1:],         # 标签序列(错位1步)
        'loss_masks': loss_mask    # 损失掩码(忽略填充位置)
    }

2.增强类:UniversalTimeMoEWindowDataset

2.1 功能概述

实现了一种打包技术(pack technique),将多个短序列合并成一个固定长度的窗口,提高数据利用率和训练效率。

2.2 关键参数

  • shuffle:是否随机打乱序列顺序(默认为 False)
  • 其他参数与TimeMoEWindowDataset类似

2.3 初始化逻辑

def __init__(self, dataset, context_length, prediction_length=0, shuffle=False):
    self.dataset = dataset
    self.window_size = context_length + prediction_length
    
    self.window_info_list = []  # 存储窗口信息(每个窗口包含多个子序列片段)
    cur_window_info = []        # 当前窗口的子序列片段
    num_cur_remaining_points = self.window_size  # 当前窗口剩余可用长度
    
    # 遍历所有序列(可随机打乱)
    iterator = range(len(dataset))
    if shuffle:
        iterator = list(iterator)
        random.shuffle(iterator)
    
    for seq_idx in iterator:
        seq_len = dataset.get_sequence_length_by_idx(seq_idx)
        remaining_seq_len = seq_len
        
        # 将当前序列切割成多个片段,填充到窗口中
        while remaining_seq_len > 0:
            if remaining_seq_len < num_cur_remaining_points:
                # 当前序列剩余部分不足以填满窗口,全部加入
                cur_window_info.append((seq_idx, seq_len - remaining_seq_len, remaining_seq_len))
                num_cur_remaining_points -= remaining_seq_len
                remaining_seq_len = 0
            else:
                # 当前序列剩余部分可以填满窗口,截取部分加入
                cur_window_info.append((seq_idx, seq_len - remaining_seq_len, num_cur_remaining_points))
                remaining_seq_len -= num_cur_remaining_points
                # 当前窗口已满,添加到结果列表并重置
                self.window_info_list.append(cur_window_info)
                num_cur_remaining_points = self.window_size
                cur_window_info = []

2.4 数据获取逻辑

def __getitem__(self, window_idx):
    window_info = self.window_info_list[window_idx]
    seq = []
    # 从多个子序列片段构建完整窗口
    for seq_idx, start_idx, offset in window_info:
        part_seq = dataset[seq_idx][start_idx: start_idx + offset]
        seq.append(part_seq)
    
    # 合并所有片段
    if len(seq) == 1:
        seq = np.array(seq[0], dtype=np.float32)
    else:
        seq = np.concatenate(seq, axis=0, dtype=np.float32)
    
    return {
        'input_ids': seq[:-1],  # 输入序列
        'labels': seq[1:],     # 标签序列(错位1步)
    }

3.对比分析

  1. TimeMoEWindowDataset

    • 适用于长序列数据
    • 适合需要严格控制窗口独立性的场景
    • stride < window_size时支持重叠窗口,用于增强数据多样性
  2. UniversalTimeMoEWindowDataset

    • 适用于大量短序列数据
    • 通过打包技术减少填充,提高训练效率
    • 适合自回归模型(如 GPT 类模型),允许不同序列之间的信息流动

性能与内存权衡:

  • TimeMoEWindowDataset

    • 预计算所有窗口索引,内存开销较高(尤其是长序列)
    • 数据访问速度快(直接索引)
  • UniversalTimeMoEWindowDataset

    • 动态构建窗口,内存开销低
    • 数据访问时需要拼接多个片段,计算开销略高

4.总结

这两个类通过不同策略解决了时间序列数据的窗口化问题:

  • TimeMoEWindowDataset:提供简单直观的滑动窗口实现,支持灵活的重叠策略
  • UniversalTimeMoEWindowDataset:通过序列打包技术优化短序列处理,提高训练效率

两者共同构成了一个完整的时间序列数据预处理工具链,为后续模型训练提供了标准化的输入格式。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值