-
# 读取数据
# Load the feature table from CSV and take an independent NumPy copy of its
# values (later code slices and modifies this array by position).
df = pd.read_csv('features.csv')
samples = np.array(df, copy=True)
-
# Z-score标准化处理
# Z-score standardize columns 1..-2 in place; the first and the last column
# are deliberately left unprocessed.
# NOTE(review): the column means/stds are computed over every row, including
# rows that later end up in a test fold, which leaks test-set information into
# training unless the scaling is refit per fold on the training rows only.
samples[:, 1:-1] = preprocessing.scale(
    samples[:, 1:-1], axis=0, with_mean=True, with_std=True
)
-
# 划分特征和标签
# Features are columns 1..-2 and the label is the last column. The labels are
# cast to int because, without the cast, iterating the folds
# (`for fold, (train, test) in enumerate(five_folds):`) raises an error —
# presumably because `samples` has object dtype, so the target type cannot be
# inferred from the raw label column (confirm against the CSV schema).
xs, ys = samples[:, 1:-1], samples[:, -1].astype(int)
-
# 定义空列表来存储每个fold的性能指标和ROC曲线数据
# Accumulators: one entry per fold for each scalar metric, and one array per
# fold for the ROC false-positive / true-positive rates.
accuracy_list, precision_list, recall_list, specificity_list = [], [], [], []
fpr_list, tpr_list = [], []
-
# 5折交叉划分数据集
# Stratified 5-fold split: every fold keeps the class ratio of `ys`.
# With shuffle=True and no random_state the fold assignment (and therefore
# every reported metric) changed on each run; a fixed seed makes the
# experiment reproducible.
skfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# split() yields (train_indices, test_indices) pairs lazily, so `five_folds`
# can be iterated only once.
five_folds = skfold.split(xs, ys)
-
# 模型训练
import os

# The per-fold loss-curve images are written under ./img; create the directory
# so savefig does not fail when it does not exist yet.
os.makedirs('./img', exist_ok=True)

for fold, (train, test) in enumerate(five_folds):
    # Model training
    print('==========fold:', fold+1, '==========')
    # Fit the scaler on this fold's training rows only and apply it to both
    # splits, so no statistic of the held-out rows reaches the model. Z-scoring
    # is invariant to any earlier positive per-column affine rescaling, so this
    # also cancels any scaling that was applied to the whole dataset before
    # the split.
    scaler = preprocessing.StandardScaler().fit(xs[train])
    x_train = scaler.transform(xs[train])
    x_test = scaler.transform(xs[test])
    mlp = MLP(
        hidden_layer_sizes=(64, 128, 256, 512),
        solver='sgd',
        batch_size=10,
        max_iter=200,
        learning_rate_init=0.001,
        verbose=True,
    )
    mlp.fit(x_train, ys[train])

    # Plot this fold's training-loss curve in its own figure and close it after
    # saving. Without this, every fold draws onto the same figure, so
    # foldN_loss.png also contains the curves of all previous folds, and the
    # figure stays current for any later plotting.
    plt.figure()
    plt.plot(mlp.loss_curve_)
    img_path = './img/fold{}_loss.png'.format(fold+1)
    plt.xlabel('Iteration')
    plt.ylabel('loss')
    plt.title('Loss Curve')
    plt.savefig(img_path)
    plt.close()

    # Model evaluation on the held-out fold.
    y_pred = mlp.predict(x_test)
    tn, fp, fn, tp = confusion_matrix(ys[test], y_pred).ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    # If the model never predicts the positive class, tp + fp == 0. Report 0
    # (scikit-learn's zero_division convention) instead of nan, so a single
    # degenerate fold does not turn the averaged precision into nan.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn)
    specificity = tn / (tn + fp)
    accuracy_list.append(accuracy)
    precision_list.append(precision)
    recall_list.append(recall)
    specificity_list.append(specificity)

    # Compute this fold's ROC FPR/TPR and store them for averaging.
    # NOTE(review): roc_curve receives hard 0/1 predictions, so each fold's
    # "curve" is a single operating point joined to (0, 0) and (1, 1).
    # Scores from mlp.predict_proba(x_test)[:, 1] would give a full curve, but
    # only if the averaging code accepts curves with a varying number of points.
    fpr, tpr, _ = roc_curve(ys[test], y_pred)
    fpr_list.append(fpr)
    tpr_list.append(tpr)
-
# 计算各指标的平均值
# Imported under an alias because the module-level name `auc` is rebound to the
# score at the end of this block. Calling `auc(...)` directly would fail with
# "object is not callable" as soon as this code runs a second time (e.g.
# re-running a notebook cell).
from sklearn.metrics import auc as _sk_auc

# Average each scalar metric over the folds.
mean_accuracy = np.mean(accuracy_list)
mean_precision = np.mean(precision_list)
mean_recall = np.mean(recall_list)
mean_specificity = np.mean(specificity_list)

# Mean ROC curve: interpolate every fold's curve onto one shared FPR grid and
# average the TPR values point-wise. This works for curves with any number of
# points, and it builds fresh arrays, so the per-fold arrays stored in
# fpr_list / tpr_list are left unmodified.
mean_fpr = np.linspace(0, 1, 101)
mean_tpr = np.mean(
    [np.interp(mean_fpr, fold_fpr, fold_tpr) for fold_fpr, fold_tpr in zip(fpr_list, tpr_list)],
    axis=0,
)
mean_tpr[0], mean_tpr[-1] = 0.0, 1.0  # pin the mean curve to start at (0, 0) and end at (1, 1)

# Area under the mean ROC curve. The name `auc` is kept bound to the score for
# code that reads it.
auc = _sk_auc(mean_fpr, mean_tpr)
-
# 绘制ROC曲线
# Draw the mean ROC curve in a fresh figure, so it never shares axes with an
# earlier plot that is still the current figure.
plt.figure()
plt.plot(mean_fpr, mean_tpr, 'k--', lw=2)
# Pad the x/y axis limits slightly so the curve does not sit on the plot border
# and the whole curve stays easy to see.
plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
【CV作业06-3】训练机器学习模型，预测肺结节良恶性
于 2023-04-13 22:38:04 首次发布