SAITS模型在PhysioNet2012数据集上的缺失值填补实践笔记

# 导入必要的库
from benchpots.datasets import preprocess_physionet2012  # 导入physionet2012数据集预处理函数
import numpy as np
import os
import torch
from pypots.data.saving import pickle_dump

# 设置随机种子以保证实验可复现
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# 加载并预处理physionet2012数据集
physionet2012_dataset = preprocess_physionet2012(
    subset="set-a",
    pattern="point",
    rate=0.1,
)

# 打印数据集的键
print(physionet2012_dataset.keys())

# 创建测试集的缺失值指示掩码
test_X_nan = np.isnan(physionet2012_dataset["test_X"])
test_X_ori_nan = np.isnan(physionet2012_dataset["test_X_ori"])
physionet2012_dataset["test_X_indicating_mask"] = test_X_nan ^ test_X_ori_nan

# 将原始测试数据中的NaN值替换为0
physionet2012_dataset["test_X_ori"] = np.nan_to_num(physionet2012_dataset["test_X_ori"])

# 构建数据集字典
train_set = {"X": physionet2012_dataset["train_X"]}
val_set = {
    "X": physionet2012_dataset["val_X"],
    "X_ori": physionet2012_dataset["val_X_ori"]
}
test_set = {
    "X": physionet2012_dataset["test_X"],
    "X_ori": physionet2012_dataset["test_X_ori"]
}

# 导入SAITS模型及相关组件
from pypots.imputation import SAITS
from pypots.optim import Adam
from pypots.nn.functional import calc_mse

# 自动选择设备
device = "cuda" if torch.cuda.is_available() else "cpu"

# 初始化SAITS模型
saits = SAITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    n_layers=3,
    d_model=64,
    n_heads=4,
    d_k=16,
    d_v=16,
    d_ffn=128,
    dropout=0.1,
    ORT_weight=1,
    MIT_weight=1,
    batch_size=32,
    epochs=10,
    patience=3,
    optimizer=Adam(lr=1e-3),
    num_workers=0,
    device=device,
    saving_path="result_saving/imputation/saits",
    model_saving_strategy="best",
)

# 确保保存路径存在
os.makedirs("result_saving/imputation/saits", exist_ok=True)
os.makedirs("result_saving", exist_ok=True)

# 训练模型
try:
    saits.fit(train_set, val_set)
except Exception as e:
    print(f"训练过程中发生错误: {e}")
    exit(1)

# 使用训练好的模型进行预测
try:
    test_set_imputation_results = saits.predict(test_set)
except Exception as e:
    print(f"预测过程中发生错误: {e}")
    exit(1)

# 计算MSE
try:
    test_MSE = calc_mse(
        test_set_imputation_results["imputation"],
        physionet2012_dataset["test_X_ori"],
        physionet2012_dataset["test_X_indicating_mask"],
    )
    print(f"SAITS test_MSE: {test_MSE}")
except Exception as e:
    print(f"MSE计算过程中发生错误: {e}")

# 进行插补并保存结果
try:
    train_set_imputation = saits.impute(train_set)
    val_set_imputation = saits.impute(val_set)
    test_set_imputation = test_set_imputation_results["imputation"]

    dict_to_save = {
        'train_set_imputation': train_set_imputation,
        'train_set_labels': physionet2012_dataset['train_y'],
        'val_set_imputation': val_set_imputation,
        'val_set_labels': physionet2012_dataset['val_y'],
        'test_set_imputation': test_set_imputation,
        'test_set_labels': physionet2012_dataset['test_y'],
    }

    pickle_dump(dict_to_save, "result_saving/imputed_physionet2012.pkl")
except Exception as e:
    print(f"保存过程中发生错误: {e}")

代码功能概述

以上代码实现了一个时间序列数据缺失值填补的完整流程,使用了SAITS深度学习模型来处理PhysioNet2012数据集中的缺失数据。

详细步骤解释

  1. 导入库和设置随机种子

    • 导入NumPy、PyTorch等基础库
    • 设置随机种子(42)保证实验可重复性
  2. 数据预处理

    • 加载PhysioNet2012数据集(时间序列数据)
    • 创建缺失值指示掩码(标记哪些位置是缺失的)
    • 将NaN值替换为0作为临时处理
    • 划分训练集、验证集和测试集
  3. 模型构建与训练

    • 初始化SAITS模型(一种专门处理时间序列缺失值的深度学习模型)
    • 设置模型参数(层数、注意力头数、学习率等)
    • 训练模型并在验证集上评估
  4. 评估与保存结果

    • 在测试集上进行预测
    • 计算MSE(均方误差)评估模型性能
    • 保存填补后的数据和模型

核心知识点总结

1. 数据预处理

  • 缺失值处理‌:使用np.isnan()检测缺失值,np.nan_to_num()替换缺失值
  • 数据划分‌:分为训练集、验证集和测试集
  • 掩码创建‌:用异或操作(^)创建指示哪些位置是人为缺失的掩码

2. SAITS模型

  • 模型类型‌:基于Transformer的时间序列缺失值填补模型
  • 关键参数‌:
    • n_steps:时间步长(序列长度)
    • n_features:特征数量
    • n_heads:注意力头数
    • d_model:模型维度

什么是 SAITS 模型?

SAITS 是一种专门用于处理时间序列数据中的缺失值的深度学习模型。它的全称是 Self-Attention-based Imputation for Time Series,翻译成中文就是“基于自注意力机制的时间序列插补”。

时间序列数据和缺失值

首先,我们需要了解一下什么是时间序列数据和缺失值:

  • 时间序列数据:这是一种按照时间顺序排列的数据,比如股票价格、天气温度、心率监测等。这些数据通常是一系列按时间点记录的数值。
  • 缺失值:在实际收集数据的过程中,由于各种原因(如设备故障、网络问题等),某些时间点的数据可能会丢失,这就是缺失值。

为什么需要 SAITS?

处理缺失值很重要,因为如果直接忽略这些缺失值,可能会导致数据分析结果不准确。SAITS 模型就是为了更准确地填补这些缺失值而设计的。

SAITS 模型的工作原理

SAITS 模型的核心思想是利用自注意力机制(Self-Attention Mechanism)来捕捉时间序列数据中的长期依赖关系,从而更准确地预测缺失值。下面我们来一步步解释:

自注意力机制

自注意力机制是一种可以让模型在处理每个时间点的数据时,考虑其他时间点的数据的重要性的方法。简单来说,它可以帮助模型“关注”到那些对当前时间点最有帮助的信息。举个例子,如果你正在预测某个时间段的气温,自注意力机制可以帮助模型注意到前几个小时或几天的气温变化,而不是仅仅依赖最近的一个时间点。

时间序列插补

插补(Imputation)是指用某种方法填补缺失值的过程。SAITS 模型通过训练,学习如何根据已有的数据来预测缺失的数据。模型会通过多次迭代,逐渐优化其预测能力,使得填补的缺失值尽可能接近真实值。

多头自注意力机制

SAITS 模型使用了多头自注意力机制,这意味着模型可以从多个不同的角度来捕捉时间序列数据中的特征。这样可以更全面地理解数据,提高预测的准确性。

具体步骤

输入数据

将包含缺失值的时间序列数据输入到模型中。

编码阶段

模型通过自注意力机制对每个时间点的数据进行编码,生成一个表示该时间点的向量。这个过程中,模型会考虑其他时间点的数据,特别是那些对当前时间点影响较大的数据。

插补阶段

模型根据编码后的向量,预测缺失值。这个过程是通过神经网络的多层计算完成的,最终输出填补后的完整时间序列数据。

输出结果

模型输出填补了缺失值的时间序列数据,可以用于后续的分析或预测任务。

优点

  • 高精度:通过自注意力机制,SAITS 模型可以捕捉到时间序列数据中的长期依赖关系,从而更准确地填补缺失值。
  • 鲁棒性强:即使在数据缺失较多的情况下,SAITS 模型也能表现出较好的性能。
  • 灵活性:可以应用于多种类型的时间序列数据,如金融数据、医疗数据等。

总结

SAITS 模型是一种基于自注意力机制的深度学习模型,专门用于处理时间序列数据中的缺失值。它通过捕捉数据中的长期依赖关系,更准确地预测和填补缺失值,从而提高数据分析的准确性。希望这个解释对你有所帮助!

3. 训练与评估

  1. 优化器‌:使用Adam优化器,学习率1e-3
  2. 评估指标‌:MSE(均方误差)
  3. 早停机制‌:patience=3(连续3轮验证集性能不提升则停止训练)

4. 工程实践

  1. 设备选择‌:自动检测使用CPU或GPU
  2. 错误处理‌:使用try-except捕获并处理可能出现的错误
  3. 结果保存‌:使用pickle保存填补后的数据

学习建议

  1. 先理解数据‌:PhysioNet2012是时间序列数据,了解数据特点有助于理解代码
  2. 分模块学习‌:先掌握数据预处理,再学习模型构建,最后理解训练评估流程
  3. 实践练习‌:尝试修改模型参数(如层数、学习率)观察对结果的影响
  4. 扩展阅读‌:了解Transformer架构和注意力机制,这是SAITS模型的基础

这段代码展示了一个完整的时间序列数据缺失值填补流程,从数据加载到模型训练评估,适合学习如何处理现实世界中的不完整时间序列数据。

以下是SAITS模型训练的具体步骤及关键要点解析:

一、数据准备阶段

  1. 数据集加载与划分

    • 使用preprocess_physionet2012加载医疗时间序列数据,划分为训练集、验证集和测试集4
    • 对缺失值生成指示掩码(indicating_mask),标记人为缺失的位置4
  2. 缺失值预处理

    • 将原始NaN值替换为0作为临时填充
    • 通过np.isnan()检测真实缺失值,异或运算区分自然缺失与人工掩码4

二、模型初始化

  1. 参数配置

    • 设置时间步长(n_steps)和特征维度(n_features
    • 定义Transformer架构参数:层数(n_layers=3)、注意力头数(n_heads=4)、模型维度(d_model=64)4
    • 配置双任务权重(ORT_weight=1MIT_weight=1)4
  2. 硬件适配

    • 自动检测GPU/CPU设备(torch.cuda.is_available()
    • 设置批处理大小(batch_size=32)和并行工作线程数(num_workers=0)5

三、训练流程

  1. 双任务协同训练

    • MIT(掩码插补任务)‌:随机掩码20%已知值,让模型预测被隐藏的真实值4
    • ORT(观测重建任务)‌:强制模型保持已观测值的准确性,防止填补干扰原始数据4
    • 两任务损失加权求和作为总损失函数4
  2. 优化过程

    • 使用Adam优化器(lr=1e-3)进行参数更新4
    • 采用早停机制(patience=3)防止过拟合4
    • 每轮验证集评估,保存最佳模型(model_saving_strategy="best")4

四、评估与部署

  1. 性能验证

    • 在测试集计算MSE指标,对比填补值与真实值差异4
    • 分析不同缺失率下的插补准确率4
  2. 关键组件验证

    • 消融实验证明:双DMSA模块、加权组合层对性能提升至关重要4
    • MIT和ORT任务缺一不可,单独使用任一任务会导致性能下降4

五、注意事项

  1. 数据特性适配

    • 时间序列通常具有高缺失率(PhysioNet2012平均缺失率约80%)
    • 需调整掩码比例匹配实际数据缺失模式4
  2. 工程实践

    • 使用pickle_dump保存预处理数据
    • 通过os.makedirs确保结果目录存在5

该流程通过Transformer架构实现时空特征捕获,结合双任务机制平衡插补准确性与观测值保真度,适用于医疗监测、传感器数据等时序缺失场景。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值