def train_and_evaluate(model, model_name, loss_fn, X_train_normalized, y_train, X_test_normalized, y_test, epochs=50,
                       batch_size=32, *, validation_data=None, threshold=0.5):
    """Train a binary classifier and evaluate it on the test set.

    Steps: compile ``model`` with Adam (learning rate 1e-4) and ``loss_fn``;
    train it with early stopping on ``val_loss`` while a checkpoint saves the
    epoch with the best ``val_auc`` under ``saved_models/``; plot the training
    history, ROC/PR curves and confusion matrix; print a classification
    report; and compute sensitivity, specificity and F1 on the test set.

    Args:
        model: Keras model to compile and train. Its ``predict`` output is
            flattened and thresholded, so it presumably ends in a single
            sigmoid unit (binary classification) -- TODO confirm.
        model_name: Name used in console output, plots and the checkpoint
            filename.
        loss_fn: Loss passed to ``model.compile``.
        X_train_normalized: Normalized training features.
        y_train: Training labels.
        X_test_normalized: Normalized test features.
        y_test: Test labels, expected to be 0/1.
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size used by ``fit``.
        validation_data: Optional ``(X_val, y_val)`` tuple used for
            validation. When omitted, the names ``X_val_normalized`` and
            ``y_val`` are looked up in the enclosing (module) scope, exactly
            as earlier versions of this function did.
        threshold: Probability above which a test sample is predicted as
            class 1. Defaults to 0.5.

    Returns:
        dict with ``model_name``, ``model_path`` (checkpoint file),
        ``history`` (the object returned by ``fit``), ``roc_auc`` and
        ``avg_precision`` (as returned by ``plot_roc_pr_curves``),
        ``sensitivity``, ``specificity``, ``f1_score`` and the raw
        confusion-matrix counts ``tn``, ``fp``, ``fn``, ``tp``. Ratios whose
        denominator is zero are reported as NaN.
    """
    import os

    print(f"\n{'=' * 50}")
    print(f"训练模型: {model_name}")
    print(f"{'=' * 50}")

    if validation_data is None:
        # Backward compatibility: the original implementation read these
        # names from the enclosing scope instead of taking them as arguments.
        validation_data = (X_val_normalized, y_val)

    # Timestamp the checkpoint filename so repeated runs do not overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_filename = f"saved_models/{model_name}_{timestamp}.keras"
    # Ensure the target directory exists before the first checkpoint write.
    os.makedirs(os.path.dirname(model_filename), exist_ok=True)

    # Keep only the epoch with the highest validation AUC on disk.
    checkpoint = ModelCheckpoint(
        model_filename,
        monitor='val_auc',
        mode='max',
        save_best_only=True,
        verbose=1,
    )

    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss=loss_fn,
                  metrics=['accuracy',
                           tf.keras.metrics.AUC(name='auc'),
                           tf.keras.metrics.Precision(name='precision'),
                           tf.keras.metrics.Recall(name='recall'),
                           tf.keras.metrics.TruePositives(name='tp'),
                           tf.keras.metrics.TrueNegatives(name='tn'),
                           tf.keras.metrics.FalsePositives(name='fp'),
                           tf.keras.metrics.FalseNegatives(name='fn')])

    # NOTE: early stopping tracks val_loss while the checkpoint above tracks
    # val_auc. If early stopping triggers, the in-memory model evaluated below
    # holds the best-val_loss weights, which may differ from the saved file.
    early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

    history = model.fit(X_train_normalized, y_train,
                        validation_data=validation_data,
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=[early_stop, checkpoint],
                        verbose=1)

    plot_training_history(history, model_name)

    # Scores as a 1-D array, then thresholded into hard 0/1 predictions.
    y_score = model.predict(X_test_normalized).flatten()
    y_pred = (y_score > threshold).astype(int)

    roc_auc, avg_precision = plot_roc_pr_curves(y_test, y_score, model_name)
    plot_confusion_matrix(y_test, y_pred, classes=['Class 0', 'Class 1'], model_name=model_name)

    print(f"\n{model_name} 分类报告:")
    # labels=[0, 1] keeps the report valid even if a class never appears in
    # y_test / y_pred (otherwise target_names would not match the classes).
    print(classification_report(y_test, y_pred, labels=[0, 1],
                                target_names=['Class 0', 'Class 1']))

    # Pinning labels guarantees a 2x2 matrix, so ravel() always yields four
    # counts in TN, FP, FN, TP order.
    tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=[0, 1]).ravel()
    sensitivity = _safe_ratio(tp, tp + fn)
    specificity = _safe_ratio(tn, tn + fp)
    f1 = _safe_ratio(2 * tp, 2 * tp + fp + fn)

    return {
        'model_name': model_name,
        'model_path': model_filename,
        'history': history,
        'roc_auc': roc_auc,
        'avg_precision': avg_precision,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'f1_score': f1,
        'tn': int(tn),
        'fp': int(fp),
        'fn': int(fn),
        'tp': int(tp),
    }


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN when the denominator is zero."""
    return float(numerator) / float(denominator) if denominator else float('nan')
def train_and_evaluate(model, model_name, loss_fn, X_train_normalized, y_train, X_test_normalized, y_test, epochs=50,
batch_size=32):
- def: Python关键字,定义函数
- train_and_evaluate: 函数名称,表示"训练和评估"
- model: 参数,接收要训练的Keras模型对象
- model_name: 参数,字符串类型,用于标识模型名称
- loss_fn: 参数,损失函数(如交叉熵或focal loss)
- X_train_normalized: 参数,标准化后的训练集特征数据
- y_train: 参数,训练集标签
- X_test_normalized: 参数,标准化后的测试集特征数据
- y_test: 参数,测试集标签
- epochs=50: 默认参数,训练轮数,默认50
- batch_size=32: 默认参数,批量大小,默认32
"""训练模型并计算测试集的 ROC 曲线数据"""
- 函数文档字符串,说明函数功能
print(f"\n{'=' * 50}")
print(f"训练模型: {model_name}")
print(f"{'=' * 50}")
- 打印分隔线和当前训练模型名称,用于输出美化
# 添加时间戳到模型文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = f"saved_models/{model_name}_{timestamp}.keras"
- timestamp: 生成当前时间的字符串(格式:年月日_时分秒)
- model_filename: 构建模型保存路径,包含模型名和时间戳
# 添加模型检查点回调
checkpoint = ModelCheckpoint(
model_filename,
monitor='val_auc',
mode='max',
save_best_only=True,
verbose=1
)
- ModelCheckpoint: Keras回调,用于保存模型
- model_filename: 保存路径
- monitor='val_auc': 监控验证集的AUC指标
- mode='max': 监控指标越大越好
- save_best_only=True: 只保存最佳模型
- verbose=1: 显示保存信息
model.compile(optimizer=Adam(learning_rate=1e-4),
loss=loss_fn,
metrics=['accuracy',
tf.keras.metrics.AUC(name='auc'),
tf.keras.metrics.Precision(name='precision'),
tf.keras.metrics.Recall(name='recall'),
tf.keras.metrics.TruePositives(name='tp'),
tf.keras.metrics.TrueNegatives(name='tn'),
tf.keras.metrics.FalsePositives(name='fp'),
tf.keras.metrics.FalseNegatives(name='fn')])
- model.compile: 配置模型训练参数
- optimizer=Adam(learning_rate=1e-4): 使用Adam优化器,学习率0.0001
- loss=loss_fn: 使用传入的损失函数
- metrics: 监控指标列表
- accuracy: 准确率
- AUC: ROC曲线下面积
- Precision: 精确率
- Recall: 召回率
- True/False Positives/Negatives: 混淆矩阵元素
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
- EarlyStopping: 早停回调
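- monitor='val_loss': 监控验证集损失
- patience=5: 连续5轮验证集损失没有改善则停止训练
- restore_best_weights=True: 早停触发时恢复验证集损失最低那一轮的权重。注意:早停监控的是 val_loss,而前面的 ModelCheckpoint 监控的是 val_auc,因此之后在测试集上评估的内存中模型,与保存到磁盘的“最佳”模型文件不一定是同一轮的权重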
history = model.fit(X_train_normalized, y_train,
validation_data=(X_val_normalized, y_val),
epochs=epochs,
batch_size=batch_size,
callbacks=[early_stop, checkpoint],
verbose=1)
- model.fit: 训练模型
- X_train_normalized/y_train: 训练数据
- validation_data: 验证数据集 (X_val_normalized, y_val)。注意这两个变量并不是函数参数,而是在调用时从外部(全局)作用域读取的,调用前必须先定义,否则会抛出 NameError
- epochs: 训练轮数
- batch_size: 批量大小
- callbacks: 回调列表(早停和模型检查点)
- verbose=1: 显示进度条
# 绘制训练历史
plot_training_history(history, model_name)
- 调用自定义函数 plot_training_history 绘制训练过程曲线(例如损失和准确率随轮次的变化,具体内容取决于该函数的实现)
# 获取预测概率和类别
y_score = model.predict(X_test_normalized).flatten()
y_pred = (y_score > 0.5).astype(int)
- y_score: 模型对测试集的预测概率(假设模型最后一层是单个 sigmoid 输出单元;flatten() 将输出展平为一维数组)
- y_pred: 将概率转为二分类结果(>0.5为1,否则为0)
# 绘制ROC和PR曲线
roc_auc, avg_precision = plot_roc_pr_curves(y_test, y_score, model_name)
- 调用自定义函数绘制ROC和PR曲线,返回AUC和AP值
# 绘制混淆矩阵
plot_confusion_matrix(y_test, y_pred, classes=['Class 0', 'Class 1'], model_name=model_name)
- 调用自定义函数绘制混淆矩阵
# 打印分类报告
print(f"\n{model_name} 分类报告:")
print(classification_report(y_test, y_pred, target_names=['Class 0', 'Class 1']))
- 打印sklearn的分类报告(精确率、召回率、F1等)
# 计算更多评估指标
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
f1_score = 2 * tp / (2 * tp + fp + fn)
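- tn/fp/fn/tp: 从混淆矩阵获取真负/假正/假负/真正(ravel() 按 TN、FP、FN、TP 的顺序展开 2×2 矩阵;如果 y_test 和 y_pred 中只出现一个类别,矩阵会退化为 1×1,解包会失败,可传入 labels=[0, 1] 来保证得到 2×2 矩阵)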
- sensitivity: 敏感度(召回率)= TP/(TP+FN)
- specificity: 特异度 = TN/(TN+FP)
- f1_score: F1分数 = 2*(precision*recall)/(precision+recall),代码中使用的是其等价形式 2*TP/(2*TP+FP+FN)。上面三个比值在分母为 0 时(例如测试集中没有正样本)会得到 nan 并产生 RuntimeWarning
上面展示的代码截止于指标计算,没有包含 return 语句;按设计,该函数最终应返回一个包含上述所有评估指标的字典。整个函数涵盖了模型训练、评估和可视化的完整流程。