train_and_evaluate 代码详解

def train_and_evaluate(model, model_name, loss_fn, X_train_normalized, y_train, X_test_normalized, y_test, epochs=50,
                       batch_size=32):
    """训练模型并计算测试集的 ROC 曲线数据"""
    print(f"\n{'=' * 50}")
    print(f"训练模型: {model_name}")
    print(f"{'=' * 50}")

    # 添加时间戳到模型文件名
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filename = f"saved_models/{model_name}_{timestamp}.keras"

    # 添加模型检查点回调
    checkpoint = ModelCheckpoint(
        model_filename,
        monitor='val_auc',
        mode='max',
        save_best_only=True,
        verbose=1
    )

    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss=loss_fn,
                  metrics=['accuracy',
                           tf.keras.metrics.AUC(name='auc'),
                           tf.keras.metrics.Precision(name='precision'),
                           tf.keras.metrics.Recall(name='recall'),
                           tf.keras.metrics.TruePositives(name='tp'),
                           tf.keras.metrics.TrueNegatives(name='tn'),
                           tf.keras.metrics.FalsePositives(name='fp'),
                           tf.keras.metrics.FalseNegatives(name='fn')])

    early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    history = model.fit(X_train_normalized, y_train,
                        validation_data=(X_val_normalized, y_val),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=[early_stop, checkpoint],
                        verbose=1)

    # 绘制训练历史
    plot_training_history(history, model_name)

    # 获取预测概率和类别
    y_score = model.predict(X_test_normalized).flatten()
    y_pred = (y_score > 0.5).astype(int)

    # 绘制ROC和PR曲线
    roc_auc, avg_precision = plot_roc_pr_curves(y_test, y_score, model_name)

    # 绘制混淆矩阵
    plot_confusion_matrix(y_test, y_pred, classes=['Class 0', 'Class 1'], model_name=model_name)

    # 打印分类报告
    print(f"\n{model_name} 分类报告:")
    print(classification_report(y_test, y_pred, target_names=['Class 0', 'Class 1']))

    # 计算更多评估指标
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    f1_score = 2 * tp / (2 * tp + fp + fn)


def train_and_evaluate(model, model_name, loss_fn, X_train_normalized, y_train, X_test_normalized, y_test, epochs=50,
                       batch_size=32):
  • def: Python关键字,定义函数
  • train_and_evaluate: 函数名称,表示"训练和评估"
  • model: 参数,接收要训练的Keras模型对象
  • model_name: 参数,字符串类型,用于标识模型名称
  • loss_fn: 参数,损失函数(如交叉熵或focal loss)
  • X_train_normalized: 参数,标准化后的训练集特征数据
  • y_train: 参数,训练集标签
  • X_test_normalized: 参数,标准化后的测试集特征数据
  • y_test: 参数,测试集标签
  • epochs=50: 默认参数,训练轮数,默认50
  • batch_size=32: 默认参数,批量大小,默认32
    """训练模型并计算测试集的 ROC 曲线数据"""
  • 函数文档字符串,说明函数功能
    print(f"\n{'=' * 50}")
    print(f"训练模型: {model_name}")
    print(f"{'=' * 50}")
  • 打印分隔线和当前训练模型名称,用于输出美化
    # 添加时间戳到模型文件名
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filename = f"saved_models/{model_name}_{timestamp}.keras"
  • timestamp: 生成当前时间的字符串(格式:年月日_时分秒)
  • model_filename: 构建模型保存路径,包含模型名和时间戳
    # 添加模型检查点回调
    checkpoint = ModelCheckpoint(
        model_filename,
        monitor='val_auc',
        mode='max',
        save_best_only=True,
        verbose=1
    )
  • ModelCheckpoint: Keras回调,用于保存模型
    • model_filename: 保存路径
    • monitor=‘val_auc’: 监控验证集的AUC指标
    • mode=‘max’: 监控指标越大越好
    • save_best_only=True: 只保存最佳模型
    • verbose=1: 显示保存信息
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss=loss_fn,
                  metrics=['accuracy',
                           tf.keras.metrics.AUC(name='auc'),
                           tf.keras.metrics.Precision(name='precision'),
                           tf.keras.metrics.Recall(name='recall'),
                           tf.keras.metrics.TruePositives(name='tp'),
                           tf.keras.metrics.TrueNegatives(name='tn'),
                           tf.keras.metrics.FalsePositives(name='fp'),
                           tf.keras.metrics.FalseNegatives(name='fn')])
  • model.compile: 配置模型训练参数
    • optimizer=Adam(learning_rate=1e-4): 使用Adam优化器,学习率0.0001
    • loss=loss_fn: 使用传入的损失函数
    • metrics: 监控指标列表
      • accuracy: 准确率
      • AUC: ROC曲线下面积
      • Precision: 精确率
      • Recall: 召回率
      • True/False Positives/Negatives: 混淆矩阵元素
    early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
  • EarlyStopping: 早停回调
    • monitor=‘val_loss’: 监控验证集损失
    • patience=5: 连续5轮不改善则停止
    • restore_best_weights=True: 恢复最佳权重
    history = model.fit(X_train_normalized, y_train,
                        validation_data=(X_val_normalized, y_val),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=[early_stop, checkpoint],
                        verbose=1)
  • model.fit: 训练模型
    • X_train_normalized/y_train: 训练数据
    • validation_data: 验证数据集
    • epochs: 训练轮数
    • batch_size: 批量大小
    • callbacks: 回调列表(早停和模型检查点)
    • verbose=1: 显示进度条
    # 绘制训练历史
    plot_training_history(history, model_name)
  • 调用自定义函数绘制训练过程中的损失和准确率曲线
    # 获取预测概率和类别
    y_score = model.predict(X_test_normalized).flatten()
    y_pred = (y_score > 0.5).astype(int)
  • y_score: 模型对测试集的预测概率(经过sigmoid输出)
  • y_pred: 将概率转为二分类结果(>0.5为1,否则为0)
    # 绘制ROC和PR曲线
    roc_auc, avg_precision = plot_roc_pr_curves(y_test, y_score, model_name)
  • 调用自定义函数绘制ROC和PR曲线,返回AUC和AP值
    # 绘制混淆矩阵
    plot_confusion_matrix(y_test, y_pred, classes=['Class 0', 'Class 1'], model_name=model_name)
  • 调用自定义函数绘制混淆矩阵
    # 打印分类报告
    print(f"\n{model_name} 分类报告:")
    print(classification_report(y_test, y_pred, target_names=['Class 0', 'Class 1']))
  • 打印sklearn的分类报告(精确率、召回率、F1等)
    # 计算更多评估指标
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    f1_score = 2 * tp / (2 * tp + fp + fn)
  • tn/fp/fn/tp: 从混淆矩阵获取真负/假正/假负/真正
  • sensitivity: 敏感度(召回率)= TP/(TP+FN)
  • specificity: 特异度 = TN/(TN+FP)
  • f1_score: F1分数 = 2*(precision*recall)/(precision+recall)

该函数最终会返回包含所有评估指标的字典结果。每个步骤都包含完整的模型训练、评估和可视化流程。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值