# 导入必要的库
from benchpots.datasets import preprocess_physionet2012 # 导入physionet2012数据集预处理函数
import numpy as np
import os
import torch
from pypots.data.saving import pickle_dump
# 设置随机种子以保证实验可复现
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed(SEED)
# 加载并预处理physionet2012数据集
physionet2012_dataset = preprocess_physionet2012(
subset="set-a",
pattern="point",
rate=0.1,
)
# 打印数据集的键
print(physionet2012_dataset.keys())
# 创建测试集的缺失值指示掩码
test_X_nan = np.isnan(physionet2012_dataset["test_X"])
test_X_ori_nan = np.isnan(physionet2012_dataset["test_X_ori"])
physionet2012_dataset["test_X_indicating_mask"] = test_X_nan ^ test_X_ori_nan
# 将原始测试数据中的NaN值替换为0
physionet2012_dataset["test_X_ori"] = np.nan_to_num(physionet2012_dataset["test_X_ori"])
# 构建数据集字典
train_set = {"X": physionet2012_dataset["train_X"]}
val_set = {
"X": physionet2012_dataset["val_X"],
"X_ori": physionet2012_dataset["val_X_ori"]
}
test_set = {
"X": physionet2012_dataset["test_X"],
"X_ori": physionet2012_dataset["test_X_ori"]
}
# 导入SAITS模型及相关组件
from pypots.imputation import SAITS
from pypots.optim import Adam
from pypots.nn.functional import calc_mse
# 自动选择设备
device = "cuda" if torch.cuda.is_available() else "cpu"
# 初始化SAITS模型
saits = SAITS(
n_steps=physionet2012_dataset['n_steps'],
n_features=physionet2012_dataset['n_features'],
n_layers=3,
d_model=64,
n_heads=4,
d_k=16,
d_v=16,
d_ffn=128,
dropout=0.1,
ORT_weight=1,
MIT_weight=1,
batch_size=32,
epochs=10,
patience=3,
optimizer=Adam(lr=1e-3),
num_workers=0,
device=device,
saving_path="result_saving/imputation/saits",
model_saving_strategy="best",
)
# 确保保存路径存在
os.makedirs("result_saving/imputation/saits", exist_ok=True)
os.makedirs("result_saving", exist_ok=True)
# 训练模型
try:
saits.fit(train_set, val_set)
except Exception as e:
print(f"训练过程中发生错误: {e}")
exit(1)
# 使用训练好的模型进行预测
try:
test_set_imputation_results = saits.predict(test_set)
except Exception as e:
print(f"预测过程中发生错误: {e}")
exit(1)
# 计算MSE
try:
test_MSE = calc_mse(
test_set_imputation_results["imputation"],
physionet2012_dataset["test_X_ori"],
physionet2012_dataset["test_X_indicating_mask"],
)
print(f"SAITS test_MSE: {test_MSE}")
except Exception as e:
print(f"MSE计算过程中发生错误: {e}")
# 进行插补并保存结果
try:
train_set_imputation = saits.impute(train_set)
val_set_imputation = saits.impute(val_set)
test_set_imputation = test_set_imputation_results["imputation"]
dict_to_save = {
'train_set_imputation': train_set_imputation,
'train_set_labels': physionet2012_dataset['train_y'],
'val_set_imputation': val_set_imputation,
'val_set_labels': physionet2012_dataset['val_y'],
'test_set_imputation': test_set_imputation,
'test_set_labels': physionet2012_dataset['test_y'],
}
pickle_dump(dict_to_save, "result_saving/imputed_physionet2012.pkl")
except Exception as e:
print(f"保存过程中发生错误: {e}")
代码功能概述
以上代码实现了一个时间序列数据缺失值填补的完整流程,使用了SAITS深度学习模型来处理PhysioNet2012数据集中的缺失数据。
详细步骤解释
-
导入库和设置随机种子
- 导入NumPy、PyTorch等基础库
- 设置随机种子(42)保证实验可重复性
-
数据预处理
- 加载PhysioNet2012数据集(时间序列数据)
- 创建缺失值指示掩码(标记哪些位置是缺失的)
- 将NaN值替换为0作为临时处理
- 划分训练集、验证集和测试集
-
模型构建与训练
- 初始化SAITS模型(一种专门处理时间序列缺失值的深度学习模型)
- 设置模型参数(层数、注意力头数、学习率等)
- 训练模型并在验证集上评估
-
评估与保存结果
- 在测试集上进行预测
- 计算MSE(均方误差)评估模型性能
- 保存填补后的数据和模型
核心知识点总结
1. 数据预处理
- 缺失值处理:使用
np.isnan()
检测缺失值,np.nan_to_num()
替换缺失值 - 数据划分:分为训练集、验证集和测试集
- 掩码创建:用异或操作(^)创建指示哪些位置是人为缺失的掩码
2. SAITS模型
- 模型类型:基于Transformer的时间序列缺失值填补模型
- 关键参数:
n_steps
:时间步长(序列长度)n_features
:特征数量n_heads
:注意力头数d_model
:模型维度
什么是 SAITS 模型?
SAITS 是一种专门用于处理时间序列数据中的缺失值的深度学习模型。它的全称是 Self-Attention-based Imputation for Time Series,翻译成中文就是“基于自注意力机制的时间序列插补”。
时间序列数据和缺失值
首先,我们需要了解一下什么是时间序列数据和缺失值:
- 时间序列数据:这是一种按照时间顺序排列的数据,比如股票价格、天气温度、心率监测等。这些数据通常是一系列按时间点记录的数值。
- 缺失值:在实际收集数据的过程中,由于各种原因(如设备故障、网络问题等),某些时间点的数据可能会丢失,这就是缺失值。
为什么需要 SAITS?
处理缺失值很重要,因为如果直接忽略这些缺失值,可能会导致数据分析结果不准确。SAITS 模型就是为了更准确地填补这些缺失值而设计的。
SAITS 模型的工作原理
SAITS 模型的核心思想是利用自注意力机制(Self-Attention Mechanism)来捕捉时间序列数据中的长期依赖关系,从而更准确地预测缺失值。下面我们来一步步解释:
自注意力机制:
自注意力机制是一种可以让模型在处理每个时间点的数据时,考虑其他时间点的数据的重要性的方法。简单来说,它可以帮助模型“关注”到那些对当前时间点最有帮助的信息。举个例子,如果你正在预测某个时间段的气温,自注意力机制可以帮助模型注意到前几个小时或几天的气温变化,而不是仅仅依赖最近的一个时间点。
时间序列插补:
插补(Imputation)是指用某种方法填补缺失值的过程。SAITS 模型通过训练,学习如何根据已有的数据来预测缺失的数据。模型会通过多次迭代,逐渐优化其预测能力,使得填补的缺失值尽可能接近真实值。
多头自注意力机制:
SAITS 模型使用了多头自注意力机制,这意味着模型可以从多个不同的角度来捕捉时间序列数据中的特征。这样可以更全面地理解数据,提高预测的准确性。
具体步骤
输入数据:
将包含缺失值的时间序列数据输入到模型中。
编码阶段:
模型通过自注意力机制对每个时间点的数据进行编码,生成一个表示该时间点的向量。这个过程中,模型会考虑其他时间点的数据,特别是那些对当前时间点影响较大的数据。
插补阶段:
模型根据编码后的向量,预测缺失值。这个过程是通过神经网络的多层计算完成的,最终输出填补后的完整时间序列数据。
输出结果:
模型输出填补了缺失值的时间序列数据,可以用于后续的分析或预测任务。
优点
- 高精度:通过自注意力机制,SAITS 模型可以捕捉到时间序列数据中的长期依赖关系,从而更准确地填补缺失值。
- 鲁棒性强:即使在数据缺失较多的情况下,SAITS 模型也能表现出较好的性能。
- 灵活性:可以应用于多种类型的时间序列数据,如金融数据、医疗数据等。
总结
SAITS 模型是一种基于自注意力机制的深度学习模型,专门用于处理时间序列数据中的缺失值。它通过捕捉数据中的长期依赖关系,更准确地预测和填补缺失值,从而提高数据分析的准确性。希望这个解释对你有所帮助!
3. 训练与评估
- 优化器:使用Adam优化器,学习率1e-3
- 评估指标:MSE(均方误差)
- 早停机制:patience=3(连续3轮验证集性能不提升则停止训练)
4. 工程实践
- 设备选择:自动检测使用CPU或GPU
- 错误处理:使用try-except捕获并处理可能出现的错误
- 结果保存:使用pickle保存填补后的数据
学习建议
- 先理解数据:PhysioNet2012是时间序列数据,了解数据特点有助于理解代码
- 分模块学习:先掌握数据预处理,再学习模型构建,最后理解训练评估流程
- 实践练习:尝试修改模型参数(如层数、学习率)观察对结果的影响
- 扩展阅读:了解Transformer架构和注意力机制,这是SAITS模型的基础
这段代码展示了一个完整的时间序列数据缺失值填补流程,从数据加载到模型训练评估,适合学习如何处理现实世界中的不完整时间序列数据。
以下是SAITS模型训练的具体步骤及关键要点解析:
一、数据准备阶段
-
数据集加载与划分
- 使用
preprocess_physionet2012
加载医疗时间序列数据,划分为训练集、验证集和测试集4 - 对缺失值生成指示掩码(
indicating_mask
),标记人为缺失的位置4
- 使用
-
缺失值预处理
- 将原始NaN值替换为0作为临时填充
- 通过
np.isnan()
检测真实缺失值,异或运算区分自然缺失与人工掩码4
二、模型初始化
-
参数配置
- 设置时间步长(
n_steps
)和特征维度(n_features
) - 定义Transformer架构参数:层数(
n_layers=3
)、注意力头数(n_heads=4
)、模型维度(d_model=64
)4 - 配置双任务权重(
ORT_weight=1
,MIT_weight=1
)4
- 设置时间步长(
-
硬件适配
- 自动检测GPU/CPU设备(
torch.cuda.is_available()
) - 设置批处理大小(
batch_size=32
)和并行工作线程数(num_workers=0
)5
- 自动检测GPU/CPU设备(
三、训练流程
-
双任务协同训练
- MIT(掩码插补任务):随机掩码20%已知值,让模型预测被隐藏的真实值4
- ORT(观测重建任务):强制模型保持已观测值的准确性,防止填补干扰原始数据4
- 两任务损失加权求和作为总损失函数4
-
优化过程
- 使用Adam优化器(
lr=1e-3
)进行参数更新4 - 采用早停机制(
patience=3
)防止过拟合4 - 每轮验证集评估,保存最佳模型(
model_saving_strategy="best"
)4
- 使用Adam优化器(
四、评估与部署
-
性能验证
- 在测试集计算MSE指标,对比填补值与真实值差异4
- 分析不同缺失率下的插补准确率4
-
关键组件验证
- 消融实验证明:双DMSA模块、加权组合层对性能提升至关重要4
- MIT和ORT任务缺一不可,单独使用任一任务会导致性能下降4
五、注意事项
-
数据特性适配
- 时间序列通常具有高缺失率(PhysioNet2012平均缺失率约80%)
- 需调整掩码比例匹配实际数据缺失模式4
-
工程实践
- 使用
pickle_dump
保存预处理数据 - 通过
os.makedirs
确保结果目录存在5
- 使用
该流程通过Transformer架构实现时空特征捕获,结合双任务机制平衡插补准确性与观测值保真度,适用于医疗监测、传感器数据等时序缺失场景。