import logging
from benchpots.datasets import preprocess_physionet2012
from pypots.classification import BRITS
from pypots.nn.functional.classification import calc_binary_classification_metrics
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def load_data(subset="set-a", pattern="point", rate=0.1):
try:
dataset = preprocess_physionet2012(subset=subset, pattern=pattern, rate=rate)
return {
"train": {"X": dataset['train_X'], "y": dataset['train_y']},
"val": {"X": dataset['val_X'], "y": dataset['val_y']},
"test": {"X": dataset['test_X'], "y": dataset['test_y']}
}
except KeyError as e:
logging.error(f"数据加载失败: 缺少键值 {e}")
raise
except Exception as e:
logging.error(f"数据加载失败: {e}")
raise
def train_model(dataset, n_steps, n_features, n_classes):
try:
model = BRITS(
n_steps=n_steps,
n_features=n_features,
n_classes=n_classes,
rnn_hidden_size=128,
epochs=20,
patience=5
)
model.fit(dataset["train"], dataset["val"])
return model
except Exception as e:
logging.error(f"模型训练失败: {e}")
raise
def evaluate_model(model, test_dataset):
try:
results = model.predict(test_dataset)
prediction = results["classification"]
metrics = calc_binary_classification_metrics(prediction, test_dataset["y"])
logging.info(f"测试集 ROC-AUC: {metrics['roc_auc']:.4f}")
logging.info(f"测试集 PR-AUC: {metrics['pr_auc']:.4f}")
return metrics
except Exception as e:
logging.error(f"模型评估失败: {e}")
raise
if __name__ == "__main__":
# 1. 加载数据
dataset = load_data(subset="set-a", pattern="point", rate=0.1)
# 2. 训练模型
model = train_model(
dataset,
n_steps=dataset["train"]["X"].shape[1],
n_features=dataset["train"]["X"].shape[2],
n_classes=len(set(dataset["train"]["y"]))
)
# 3. 评估模型
metrics = evaluate_model(model, dataset["test"])
print(f"BRITS 在测试集上的 ROC-AUC: {metrics['roc_auc']:.4f}")
print(f"BRITS 在测试集上的 PR-AUC: {metrics['pr_auc']:.4f}")
代码功能概述
该代码的主要功能是使用 BRITS 模型对 PhysioNet2012 数据集进行分类任务,并通过模块化设计实现数据加载、模型训练和评估的过程。代码分为以下几个主要部分:
1. 日志配置:设置日志记录机制。
2. 数据加载函数:封装数据加载逻辑,便于复用。
3. 模型训练函数:初始化并训练 BRITS 模型。
4. 模型评估函数:计算模型在测试集上的性能指标(ROC-AUC 和 PR-AUC)。
5. 主程序逻辑:调用上述函数完成整个流程。
代码详细解析
1. 日志配置
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
功能:配置日志记录机制,设置日志级别为 INFO,格式为时间戳、日志级别和消息内容。
作用:
提供统一的日志管理方式,方便调试和监控。
替代直接打印日志信息,便于扩展(如写入文件或远程发送日志)。
2. 数据加载函数
def load_data(subset="set-a", pattern="point", rate=0.1):
try:
dataset = preprocess_physionet2012(subset=subset, pattern=pattern, rate=rate)
return {
"train": {"X": dataset['train_X'], "y": dataset['train_y']},
"val": {"X": dataset['val_X'], "y": dataset['val_y']},
"test": {"X": dataset['test_X'], "y": dataset['test_y']}
}
except KeyError as e:
logging.error(f"数据加载失败: 缺少键值 {e}")
raise
except Exception as e:
logging.error(f"数据加载失败: {e}")
raise
功能:加载并划分 PhysioNet2012 数据集为训练集、验证集和测试集。
参数:
subset="set-a":指定加载的数据子集。
pattern="point":指定缺失模式(点状缺失)。
rate=0.1:验证集和测试集中遮蔽的比例。
返回值:
返回一个字典,包含训练集、验证集和测试集的数据。
异常处理:
如果数据中缺少必要键值(如 train_X),捕获 KeyError 并记录错误日志。
捕获其他未知异常并记录错误日志。
3. 模型训练函数
def train_model(dataset, n_steps, n_features, n_classes):
try:
model = BRITS(
n_steps=n_steps,
n_features=n_features,
n_classes=n_classes,
rnn_hidden_size=128,
epochs=20,
patience=5
)
model.fit(dataset["train"], dataset["val"])
return model
except Exception as e:
logging.error(f"模型训练失败: {e}")
raise
功能:初始化 BRITS模型并进行训练。
参数:
dataset:包含训练集和验证集的数据字典。
n_steps:时间序列的时间步数。
n_features:每个时间步的特征数。
n_classes:分类任务的类别数。
模型参数:
rnn_hidden_size=128:RNN 隐藏层大小。
epochs=20:最大训练轮数。
patience=5:早停机制的耐心值。
返回值:返回训练好的 BRITS模型。
异常处理:捕获训练过程中可能出现的异常并记录错误日志。
4. 模型评估函数
def evaluate_model(model, test_dataset):
try:
results = model.predict(test_dataset)
prediction = results["classification"]
metrics = calc_binary_classification_metrics(prediction, test_dataset["y"])
logging.info(f"测试集 ROC-AUC: {metrics['roc_auc']:.4f}")
logging.info(f"测试集 PR-AUC: {metrics['pr_auc']:.4f}")
return metrics
except Exception as e:
logging.error(f"模型评估失败: {e}")
raise
功能:对测试集进行预测,并计算性能指标(ROC-AUC 和 PR-AUC)。
参数:
model:训练好的 BRITS 模型。
test_dataset:测试集数据。
逻辑:
1. 使用 model.predict 方法对测试集进行预测。
2. 提取预测结果中的分类标签 prediction。
3. 使用 calc_binary_classification_metrics计算 ROC-AUC 和 PR-AUC。
4. 记录并返回评估结果。
异常处理:捕获评估过程中可能出现的异常并记录错误日志。
5. 主程序逻辑
if __name__ == "__main__":
# 1. 加载数据
dataset = load_data(subset="set-a", pattern="point", rate=0.1)
# 2. 训练模型
model = train_model(
dataset,
n_steps=dataset["train"]["X"].shape[1],
n_features=dataset["train"]["X"].shape[2],
n_classes=len(set(dataset["train"]["y"]))
)
# 3. 评估模型
metrics = evaluate_model(model, dataset["test"])
print(f"BRITS 在测试集上的 ROC-AUC: {metrics['roc_auc']:.4f}")
print(f"BRITS 在测试集上的 PR-AUC: {metrics['pr_auc']:.4f}")
功能:调用上述函数完成整个流程。
步骤:
1. 加载数据:
调用 load_data 函数加载 PhysioNet2012数据集。
2. 训练模型:
调用 train_model函数初始化并训练 BRITS 模型。
从训练集中提取 n_steps(时间步数)、n_features(特征数)和 n_classes(类别数)。
3. 评估模型:
调用 evaluate_model函数对测试集进行预测并计算性能指标。
打印 ROC-AUC 和 PR-AUC 的评估结果。
代码优点
模块化设计:
将数据加载、模型训练和评估封装为独立函数,便于复用和调试。
日志管理:
使用 logging 模块替代直接打印日志信息,提供统一的日志管理方式。
鲁棒性:
增加了异常处理机制,能够捕获常见错误(如键值缺失或运行时错误)。
灵活性:
通过参数化配置,可以轻松调整数据集、模型超参数和训练策略。
改进建议
仍有一些改进方向:
超参数管理:
可以使用配置文件(如 JSON 或 YAML)来管理超参数,避免硬编码。
GPU 支持:
如果适用,可以增加 GPU 支持以加速模型训练和推理过程。
扩展性:
可以将模型替换为其他时间序列分类模型(如 GRU-D 或 Transformer),增加代码的通用性。
可视化:
增加训练过程的可视化(如损失曲线或准确率曲线),便于分析模型性能。
总结
以上代码通过模块化设计和日志管理,实现了从数据加载到模型评估的完整流程。代码结构清晰、易于维护,并且具有良好的扩展性和鲁棒性。
学习建议
以下是基于以上代码提供的系统化学习建议。这些建议涵盖了从基础知识到高级应用的多个层面,帮助学习者逐步掌握时间序列分类任务的相关技能。
学习计划表格
阶段 | 目标 | 学习内容 | 实践建议 |
---|---|---|---|
1. 学习基础知识 | 了解时间序列数据的特点及其在实际问题中的应用 | 时间序列的基本概念(如时间步、特征维度等),缺失值处理方法(如插值、遮蔽等),常见的时间序列任务类型(如分类、回归、异常检测等) | 使用 pandas 或 numpy 加载和可视化时间序列数据,尝试对时间序列数据进行预处理(如标准化、缺失值填充等) |
熟练使用 Python 进行数据处理和模型开发 | 数据结构(如列表、字典、数组),文件操作(如读取 CSV 文件),函数定义与调用 | 编写简单的函数实现数据加载和预处理功能,练习异常处理机制(如 try-except 块) | |
理解机器学习的基本概念和常见算法 | 监督学习与无监督学习的区别,分类任务的基本评估指标(如准确率、ROC-AUC、PR-AUC),模型训练的基本流程(如数据划分、模型选择、超参数调优) | 使用 scikit-learn 实现一个简单的分类任务,计算并绘制 ROC 曲线和 PR 曲线 | |
2. 学习时间序列分类模型 | 掌握 BRITS 模型的工作原理及其适用场景 | 双向 RNN 的基本概念,BRITS 模型如何处理带缺失值的时间序列数据,BRITS 的输入输出格式及超参数设置 | 阅读 BRITS 模型的原始论文或相关资料,使用 pypots 库中的 BRITS 模型完成一个简单的分类任务 |
扩展知识面,了解其他时间序列分类模型 | 基于 RNN 的模型(如 LSTM、GRU),基于注意力机制的模型(如 Transformer),特殊设计的模型(如 GRU-D、MTAD-GAT) | 使用 tensorflow 或 pytorch 实现一个简单的 RNN 或 Transformer 模型,对比不同模型在相同任务上的性能表现 | |
3. 学习模块化编程与日志管理 | 掌握将复杂任务分解为多个小模块的能力 | 函数封装的意义和方法,参数化配置的使用(如通过函数参数传递超参数),模块之间的依赖关系和调用逻辑 | 将代码进一步拆分为更小的模块(如数据增强、模型保存等),使用配置文件(如 JSON 或 YAML)管理超参数 |
学会记录程序运行过程中的信息以方便调试和监控 | logging 模块的基本用法,不同日志级别的含义(如 INFO 、WARNING 、ERROR ),如何将日志写入文件或远程服务器 | 在项目中加入日志记录功能,尝试将日志写入文件并分析其内容 | |
4. 学习性能优化与扩展 | 提高模型训练和推理的效率 | GPU 加速的基本原理,如何在代码中启用 GPU 支持,超参数调优的方法(如网格搜索、贝叶斯优化) | 在支持 GPU 的环境中运行代码并比较速度差异,使用 optuna 或 hyperopt 工具进行超参数调优 |
使代码能够适应不同的任务和需求 | 如何替换模型以支持其他任务(如回归或异常检测),如何扩展数据预处理逻辑以支持更多数据集,如何将代码封装为可复用的库或工具 | 尝试将代码应用于其他时间序列数据集(如 UCI 数据集),将代码封装为一个 Python 包并发布到 PyPI | |
5. 学习可视化与报告生成 | 通过可视化手段更好地理解数据和模型性能 | 时间序列数据的可视化方法(如折线图、热力图),模型训练过程的可视化(如损失曲线、准确率曲线) | 使用 matplotlib 或 seaborn 绘制时间序列数据,使用 tensorboard 可视化模型训练过程 |
总结实验结果并生成可共享的报告 | Markdown 和 Jupyter Notebook 的基本用法,如何将代码、图表和文字说明整合到一个文档中 | 使用 Jupyter Notebook 记录实验过程和结果,将最终报告导出为 PDF 或 HTML 格式 | |
6. 学习开源社区与协作 | 通过参与开源项目提升实战能力 | 如何阅读和理解开源项目的代码,如何提交代码贡献(如修复 bug 或添加新功能) | 浏览 pypots 或其他时间序列相关的开源项目,提交一个小的功能改进或文档更新 |
掌握团队协作开发的基本技能 | Git 的基本用法(如分支管理、代码合并),如何编写清晰的代码注释和文档 | 使用 GitHub 创建一个时间序列分类项目,邀请朋友或同事一起开发和测试代码 |
通过以上学习路径,可以从零开始逐步掌握时间序列分类任务的相关知识和技能。建议按照以下步骤推进:
1. 夯实基础:从时间序列数据和 Python 编程基础学起。
2. 深入模型:重点学习 BRITS 模型及其应用场景。
3. 提升能力:掌握模块化编程、日志管理和性能优化技巧。
4. 扩展视野:尝试其他模型和任务,积累实践经验。
5. 分享成果:通过可视化和报告生成展示您的研究成果。
推荐阅读材料
Cao, W., Wang, D., Li, J., Zhou, H., Li, L., & Li, Y. (2018). BRITS: Bidirectional Recurrent Imputation for Time Series. NeurIPS 2018.
推荐原因:作为时序数据插补领域的里程碑式研究,本文在GRU-D和M-RNN等经典模型的基础上进行了创新性改进,显著提升了模型性能。该研究发表于人工智能顶级会议NeurIPS 2018,截至2025年5月,其Google Scholar引用量已突破800次,充分体现了其在学术界的重要影响力。