BRITS模型在PhysioNet2012数据集上的应用

import logging
from benchpots.datasets import preprocess_physionet2012
from pypots.classification import BRITS
from pypots.nn.functional.classification import calc_binary_classification_metrics

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def load_data(subset="set-a", pattern="point", rate=0.1):
    try:
        dataset = preprocess_physionet2012(subset=subset, pattern=pattern, rate=rate)
        return {
            "train": {"X": dataset['train_X'], "y": dataset['train_y']},
            "val": {"X": dataset['val_X'], "y": dataset['val_y']},
            "test": {"X": dataset['test_X'], "y": dataset['test_y']}
        }
    except KeyError as e:
        logging.error(f"数据加载失败: 缺少键值 {e}")
        raise
    except Exception as e:
        logging.error(f"数据加载失败: {e}")
        raise

def train_model(dataset, n_steps, n_features, n_classes):
    try:
        model = BRITS(
            n_steps=n_steps,
            n_features=n_features,
            n_classes=n_classes,
            rnn_hidden_size=128,
            epochs=20,
            patience=5
        )
        model.fit(dataset["train"], dataset["val"])
        return model
    except Exception as e:
        logging.error(f"模型训练失败: {e}")
        raise

def evaluate_model(model, test_dataset):
    try:
        results = model.predict(test_dataset)
        prediction = results["classification"]
        metrics = calc_binary_classification_metrics(prediction, test_dataset["y"])
        logging.info(f"测试集 ROC-AUC: {metrics['roc_auc']:.4f}")
        logging.info(f"测试集 PR-AUC: {metrics['pr_auc']:.4f}")
        return metrics
    except Exception as e:
        logging.error(f"模型评估失败: {e}")
        raise

if __name__ == "__main__":
    # 1. 加载数据
    dataset = load_data(subset="set-a", pattern="point", rate=0.1)

    # 2. 训练模型
    model = train_model(
        dataset,
        n_steps=dataset["train"]["X"].shape[1],
        n_features=dataset["train"]["X"].shape[2],
        n_classes=len(set(dataset["train"]["y"]))
    )

    # 3. 评估模型
    metrics = evaluate_model(model, dataset["test"])
    print(f"BRITS 在测试集上的 ROC-AUC: {metrics['roc_auc']:.4f}")
    print(f"BRITS 在测试集上的 PR-AUC: {metrics['pr_auc']:.4f}")

代码功能概述

该代码的主要功能是使用 BRITS 模型对 PhysioNet2012 数据集进行分类任务,并通过模块化设计实现数据加载、模型训练和评估的过程。代码分为以下几个主要部分:

1. 日志配置:设置日志记录机制。
2. 数据加载函数:封装数据加载逻辑,便于复用。
3. 模型训练函数:初始化并训练 BRITS 模型。
4. 模型评估函数:计算模型在测试集上的性能指标(ROC-AUC 和 PR-AUC)。
5. 主程序逻辑:调用上述函数完成整个流程。

代码详细解析

1. 日志配置
 

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')



功能:配置日志记录机制,设置日志级别为 INFO,格式为时间戳、日志级别和消息内容。
作用:
        提供统一的日志管理方式,方便调试和监控。
        替代直接打印日志信息,便于扩展(如写入文件或远程发送日志)。

2. 数据加载函数

def load_data(subset="set-a", pattern="point", rate=0.1):
    try:
        dataset = preprocess_physionet2012(subset=subset, pattern=pattern, rate=rate)
        return {
            "train": {"X": dataset['train_X'], "y": dataset['train_y']},
            "val": {"X": dataset['val_X'], "y": dataset['val_y']},
            "test": {"X": dataset['test_X'], "y": dataset['test_y']}
        }
    except KeyError as e:
        logging.error(f"数据加载失败: 缺少键值 {e}")
        raise
    except Exception as e:
        logging.error(f"数据加载失败: {e}")
        raise

功能:加载并划分 PhysioNet2012 数据集为训练集、验证集和测试集。
参数:
        subset="set-a":指定加载的数据子集。
        pattern="point":指定缺失模式(点状缺失)。
        rate=0.1:验证集和测试集中遮蔽的比例。
返回值:
        返回一个字典,包含训练集、验证集和测试集的数据。
异常处理:
        如果数据中缺少必要键值(如 train_X),捕获 KeyError 并记录错误日志。
        捕获其他未知异常并记录错误日志。

3. 模型训练函数

def train_model(dataset, n_steps, n_features, n_classes):
    try:
        model = BRITS(
            n_steps=n_steps,
            n_features=n_features,
            n_classes=n_classes,
            rnn_hidden_size=128,
            epochs=20,
            patience=5
        )
        model.fit(dataset["train"], dataset["val"])
        return model
    except Exception as e:
        logging.error(f"模型训练失败: {e}")
        raise

功能:初始化 BRITS模型并进行训练。
参数:
        dataset:包含训练集和验证集的数据字典。
        n_steps:时间序列的时间步数。
        n_features:每个时间步的特征数。
        n_classes:分类任务的类别数。
模型参数:
        rnn_hidden_size=128:RNN 隐藏层大小。
        epochs=20:最大训练轮数。
        patience=5:早停机制的耐心值。

返回值:返回训练好的 BRITS模型。
异常处理:捕获训练过程中可能出现的异常并记录错误日志。

4. 模型评估函数

def evaluate_model(model, test_dataset):
    try:
        results = model.predict(test_dataset)
        prediction = results["classification"]
        metrics = calc_binary_classification_metrics(prediction, test_dataset["y"])
        logging.info(f"测试集 ROC-AUC: {metrics['roc_auc']:.4f}")
        logging.info(f"测试集 PR-AUC: {metrics['pr_auc']:.4f}")
        return metrics
    except Exception as e:
        logging.error(f"模型评估失败: {e}")
        raise

功能:对测试集进行预测,并计算性能指标(ROC-AUC 和 PR-AUC)。
参数:
        model:训练好的 BRITS 模型。
        test_dataset:测试集数据。
逻辑:
  1. 使用 model.predict 方法对测试集进行预测。
  2. 提取预测结果中的分类标签 prediction。
  3. 使用 calc_binary_classification_metrics计算 ROC-AUC 和 PR-AUC。
  4. 记录并返回评估结果。
异常处理:捕获评估过程中可能出现的异常并记录错误日志。


5. 主程序逻辑

if __name__ == "__main__":
    # 1. 加载数据
    dataset = load_data(subset="set-a", pattern="point", rate=0.1)

    # 2. 训练模型
    model = train_model(
        dataset,
        n_steps=dataset["train"]["X"].shape[1],
        n_features=dataset["train"]["X"].shape[2],
        n_classes=len(set(dataset["train"]["y"]))
    )

    # 3. 评估模型
    metrics = evaluate_model(model, dataset["test"])
    print(f"BRITS 在测试集上的 ROC-AUC: {metrics['roc_auc']:.4f}")
    print(f"BRITS 在测试集上的 PR-AUC: {metrics['pr_auc']:.4f}")

功能:调用上述函数完成整个流程。
步骤:
  1. 加载数据:
     调用 load_data 函数加载 PhysioNet2012数据集。
  2. 训练模型:
     调用 train_model函数初始化并训练 BRITS 模型。
     从训练集中提取 n_steps(时间步数)、n_features(特征数)和 n_classes(类别数)。
  3. 评估模型:
     调用 evaluate_model函数对测试集进行预测并计算性能指标。
     打印 ROC-AUC 和 PR-AUC 的评估结果。

代码优点


模块化设计:
        将数据加载、模型训练和评估封装为独立函数,便于复用和调试。
日志管理:
        使用 logging 模块替代直接打印日志信息,提供统一的日志管理方式。
鲁棒性:
        增加了异常处理机制,能够捕获常见错误(如键值缺失或运行时错误)。
灵活性:
        通过参数化配置,可以轻松调整数据集、模型超参数和训练策略。

改进建议

仍有一些改进方向:

超参数管理:
        可以使用配置文件(如 JSON 或 YAML)来管理超参数,避免硬编码。
GPU 支持:
        如果适用,可以增加 GPU 支持以加速模型训练和推理过程。
扩展性:
        可以将模型替换为其他时间序列分类模型(如 GRU-D 或 Transformer),增加代码的通用性。
可视化:
        增加训练过程的可视化(如损失曲线或准确率曲线),便于分析模型性能。


总结

以上代码通过模块化设计和日志管理,实现了从数据加载到模型评估的完整流程。代码结构清晰、易于维护,并且具有良好的扩展性和鲁棒性。

学习建议

以下是基于以上代码提供的系统化学习建议。这些建议涵盖了从基础知识到高级应用的多个层面,帮助学习者逐步掌握时间序列分类任务的相关技能。

学习计划表格

阶段目标学习内容实践建议
1. 学习基础知识了解时间序列数据的特点及其在实际问题中的应用时间序列的基本概念(如时间步、特征维度等),缺失值处理方法(如插值、遮蔽等),常见的时间序列任务类型(如分类、回归、异常检测等)使用 pandasnumpy 加载和可视化时间序列数据,尝试对时间序列数据进行预处理(如标准化、缺失值填充等)
熟练使用 Python 进行数据处理和模型开发数据结构(如列表、字典、数组),文件操作(如读取 CSV 文件),函数定义与调用编写简单的函数实现数据加载和预处理功能,练习异常处理机制(如 try-except 块)
理解机器学习的基本概念和常见算法监督学习与无监督学习的区别,分类任务的基本评估指标(如准确率、ROC-AUC、PR-AUC),模型训练的基本流程(如数据划分、模型选择、超参数调优)使用 scikit-learn 实现一个简单的分类任务,计算并绘制 ROC 曲线和 PR 曲线
2. 学习时间序列分类模型掌握 BRITS 模型的工作原理及其适用场景双向 RNN 的基本概念,BRITS 模型如何处理带缺失值的时间序列数据,BRITS 的输入输出格式及超参数设置阅读 BRITS 模型的原始论文或相关资料,使用 pypots 库中的 BRITS 模型完成一个简单的分类任务
扩展知识面,了解其他时间序列分类模型基于 RNN 的模型(如 LSTM、GRU),基于注意力机制的模型(如 Transformer),特殊设计的模型(如 GRU-D、MTAD-GAT)使用 tensorflowpytorch 实现一个简单的 RNN 或 Transformer 模型,对比不同模型在相同任务上的性能表现
3. 学习模块化编程与日志管理掌握将复杂任务分解为多个小模块的能力函数封装的意义和方法,参数化配置的使用(如通过函数参数传递超参数),模块之间的依赖关系和调用逻辑将代码进一步拆分为更小的模块(如数据增强、模型保存等),使用配置文件(如 JSON 或 YAML)管理超参数
学会记录程序运行过程中的信息以方便调试和监控logging 模块的基本用法,不同日志级别的含义(如 INFOWARNINGERROR),如何将日志写入文件或远程服务器在项目中加入日志记录功能,尝试将日志写入文件并分析其内容
4. 学习性能优化与扩展提高模型训练和推理的效率GPU 加速的基本原理,如何在代码中启用 GPU 支持,超参数调优的方法(如网格搜索、贝叶斯优化)在支持 GPU 的环境中运行代码并比较速度差异,使用 optunahyperopt 工具进行超参数调优
使代码能够适应不同的任务和需求如何替换模型以支持其他任务(如回归或异常检测),如何扩展数据预处理逻辑以支持更多数据集,如何将代码封装为可复用的库或工具尝试将代码应用于其他时间序列数据集(如 UCI 数据集),将代码封装为一个 Python 包并发布到 PyPI
5. 学习可视化与报告生成通过可视化手段更好地理解数据和模型性能时间序列数据的可视化方法(如折线图、热力图),模型训练过程的可视化(如损失曲线、准确率曲线)使用 matplotlibseaborn 绘制时间序列数据,使用 tensorboard 可视化模型训练过程
总结实验结果并生成可共享的报告Markdown 和 Jupyter Notebook 的基本用法,如何将代码、图表和文字说明整合到一个文档中使用 Jupyter Notebook 记录实验过程和结果,将最终报告导出为 PDF 或 HTML 格式
6. 学习开源社区与协作通过参与开源项目提升实战能力如何阅读和理解开源项目的代码,如何提交代码贡献(如修复 bug 或添加新功能)浏览 pypots 或其他时间序列相关的开源项目,提交一个小的功能改进或文档更新
掌握团队协作开发的基本技能Git 的基本用法(如分支管理、代码合并),如何编写清晰的代码注释和文档使用 GitHub 创建一个时间序列分类项目,邀请朋友或同事一起开发和测试代码

通过以上学习路径,可以从零开始逐步掌握时间序列分类任务的相关知识和技能。建议按照以下步骤推进:

1. 夯实基础:从时间序列数据和 Python 编程基础学起。
2. 深入模型:重点学习 BRITS 模型及其应用场景。
3. 提升能力:掌握模块化编程、日志管理和性能优化技巧。
4. 扩展视野:尝试其他模型和任务,积累实践经验。
5. 分享成果:通过可视化和报告生成展示您的研究成果。

推荐阅读材料

Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018.

推荐原因:作为时序数据插补领域的里程碑式研究,本文在GRU-D和M-RNN等经典模型的基础上进行了创新性改进,显著提升了模型性能。该研究发表于人工智能顶级会议NeurIPS 2018,截至2025年5月,其Google Scholar引用量已突破800次,充分体现了其在学术界的重要影响力。

Traceback (most recent call last): File "D:/air/数据缺失填充/BRITS-Air-Quality-main - 4 - 副本/BRITS-Air-Quality-main/Air-Quality/main.py", line 156, in <module> LOSS_train, MAE_train, MRE_train, MAE_test, MRE_test = run() File "D:/air/数据缺失填充/BRITS-Air-Quality-main - 4 - 副本/BRITS-Air-Quality-main/Air-Quality/main.py", line 144, in run LOSS_train, MAE_train, MRE_train = train(model,train_data_iter) File "D:/air/数据缺失填充/BRITS-Air-Quality-main - 4 - 副本/BRITS-Air-Quality-main/Air-Quality/main.py", line 53, in train ret = model.run_on_batch(data, optimizer, epoch) File "D:\air\数据缺失填充\BRITS-Air-Quality-main - 4 - 副本\BRITS-Air-Quality-main\Air-Quality\models\aseq.py", line 171, in run_on_batch ret = self(data) File "D:\anaconda3\envs\pytorch-gpu2\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "D:\air\数据缺失填充\BRITS-Air-Quality-main - 4 - 副本\BRITS-Air-Quality-main\Air-Quality\models\aseq.py", line 63, in forward encoder_out = self.encoder(data) File "D:\anaconda3\envs\pytorch-gpu2\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "D:\air\数据缺失填充\BRITS-Air-Quality-main - 4 - 副本\BRITS-Air-Quality-main\Air-Quality\models\brits.py", line 38, in forward ret_f = self.rits_f(data, 'forward') File "D:\anaconda3\envs\pytorch-gpu2\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) File "D:\air\数据缺失填充\BRITS-Air-Quality-main - 4 - 副本\BRITS-Air-Quality-main\Air-Quality\models\rits.py", line 174, in forward h = h * gamma_h RuntimeError: The size of tensor a (14) must match the size of tensor b (64) at non-singleton dimension 0 进程已结束,退出代码 1
07-25
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值