源码:https://github.com/Time-MoE/Time-MoE
这段代码实现了两个用于时间序列数据处理的窗口化数据集类,主要用于将长序列切割成固定长度的子序列,为模型训练提供合适的输入格式。
1. 核心类:TimeMoEWindowDataset
1.1 功能概述
将长时间序列转换为固定长度的非重叠滑动窗口,每个窗口包含:
- 输入序列(
input_ids
):长度为context_length
- 标签序列(
labels
):长度为context_length + prediction_length
,与输入序列错位 1 个时间步 - 损失掩码(
loss_masks
):标记哪些位置需要计算损失
1.2 关键参数
context_length
:输入序列长度(历史信息)prediction_length
:预测序列长度(未来信息,默认为 0)stride
:窗口滑动步长(默认为窗口大小,即非重叠)
1.3 初始化逻辑
def __init__(self, dataset, context_length, prediction_length=0, stride=None):
self.dataset = dataset
self.context_length = context_length
self.prediction_length = prediction_length
self.window_size = context_length + prediction_length
self.stride = stride or self.window_size # 默认非重叠
# 构建子序列索引列表
self.sub_seq_indexes = []
for seq_idx in range(len(dataset)):
n_points = dataset.get_sequence_length_by_idx(seq_idx)
if n_points < 2:
continue
# 添加初始窗口
self.sub_seq_indexes.append((seq_idx, 0))
# 添加后续窗口(按stride滑动)
for offset in range(self.stride, n_points - self.window_size - 1 + 1, self.stride):
self.sub_seq_indexes.append((seq_idx, offset))
1.4数据获取逻辑
def __getitem__(self, seq_idx):
seq_i, offset_i = self.sub_seq_indexes[seq_idx]
# 提取窗口数据(包含额外1个点用于错位)
seq = self.dataset[seq_i][offset_i: offset_i + self.window_size + 1]
seq = np.array(seq, dtype=np.float32)
# 创建损失掩码(标记有效位置)
loss_mask = np.ones(len(seq) - 1, dtype=np.int32)
# 处理序列长度不足的情况(填充0)
n_pad = self.window_size + 1 - len(seq)
if n_pad > 0:
seq = np.pad(seq, (0, n_pad), 'constant', constant_values=0)
loss_mask = np.pad(loss_mask, (0, n_pad), 'constant', constant_values=0)
return {
'input_ids': seq[:-1], # 输入序列
'labels': seq[1:], # 标签序列(错位1步)
'loss_masks': loss_mask # 损失掩码(忽略填充位置)
}
2.增强类:UniversalTimeMoEWindowDataset
2.1 功能概述
实现了一种打包技术(pack technique),将多个短序列合并成一个固定长度的窗口,提高数据利用率和训练效率。
2.2 关键参数
shuffle
:是否随机打乱序列顺序(默认为 False)- 其他参数与
TimeMoEWindowDataset
类似
2.3 初始化逻辑
def __init__(self, dataset, context_length, prediction_length=0, shuffle=False):
self.dataset = dataset
self.window_size = context_length + prediction_length
self.window_info_list = [] # 存储窗口信息(每个窗口包含多个子序列片段)
cur_window_info = [] # 当前窗口的子序列片段
num_cur_remaining_points = self.window_size # 当前窗口剩余可用长度
# 遍历所有序列(可随机打乱)
iterator = range(len(dataset))
if shuffle:
iterator = list(iterator)
random.shuffle(iterator)
for seq_idx in iterator:
seq_len = dataset.get_sequence_length_by_idx(seq_idx)
remaining_seq_len = seq_len
# 将当前序列切割成多个片段,填充到窗口中
while remaining_seq_len > 0:
if remaining_seq_len < num_cur_remaining_points:
# 当前序列剩余部分不足以填满窗口,全部加入
cur_window_info.append((seq_idx, seq_len - remaining_seq_len, remaining_seq_len))
num_cur_remaining_points -= remaining_seq_len
remaining_seq_len = 0
else:
# 当前序列剩余部分可以填满窗口,截取部分加入
cur_window_info.append((seq_idx, seq_len - remaining_seq_len, num_cur_remaining_points))
remaining_seq_len -= num_cur_remaining_points
# 当前窗口已满,添加到结果列表并重置
self.window_info_list.append(cur_window_info)
num_cur_remaining_points = self.window_size
cur_window_info = []
2.4 数据获取逻辑
def __getitem__(self, window_idx):
window_info = self.window_info_list[window_idx]
seq = []
# 从多个子序列片段构建完整窗口
for seq_idx, start_idx, offset in window_info:
part_seq = dataset[seq_idx][start_idx: start_idx + offset]
seq.append(part_seq)
# 合并所有片段
if len(seq) == 1:
seq = np.array(seq[0], dtype=np.float32)
else:
seq = np.concatenate(seq, axis=0, dtype=np.float32)
return {
'input_ids': seq[:-1], # 输入序列
'labels': seq[1:], # 标签序列(错位1步)
}
3.对比分析
-
TimeMoEWindowDataset
:- 适用于长序列数据
- 适合需要严格控制窗口独立性的场景
- 当
stride < window_size
时支持重叠窗口,用于增强数据多样性
-
UniversalTimeMoEWindowDataset
:- 适用于大量短序列数据
- 通过打包技术减少填充,提高训练效率
- 适合自回归模型(如 GPT 类模型),允许不同序列之间的信息流动
性能与内存权衡:
-
TimeMoEWindowDataset
:- 预计算所有窗口索引,内存开销较高(尤其是长序列)
- 数据访问速度快(直接索引)
-
UniversalTimeMoEWindowDataset
:- 动态构建窗口,内存开销低
- 数据访问时需要拼接多个片段,计算开销略高
4.总结
这两个类通过不同策略解决了时间序列数据的窗口化问题:
TimeMoEWindowDataset
:提供简单直观的滑动窗口实现,支持灵活的重叠策略UniversalTimeMoEWindowDataset
:通过序列打包技术优化短序列处理,提高训练效率
两者共同构成了一个完整的时间序列数据预处理工具链,为后续模型训练提供了标准化的输入格式。