一、PR曲线
PR 曲线描述了准确率(Precision)和召回率(Recall)之间的关系。准确率表示预测为正例的样本中真正例的比例,而召回率表示真正例样本中被正确预测为正例的比例。PR 曲线横轴表示召回率,纵轴表示准确率。PR 曲线越靠近左上角,说明模型性能越好。
1.混淆矩阵
真正例(True Positive,TP):指正确分类成为正的样本数,实际为正,预测为正
伪反例(False Positive,FP): 指错误分类为正的样本数,实际为负,预测为正
伪反例(False Negative,FN):指错误分类为负的样本数,实际为正,预测为负
真反例(True Negative,TN):指正确分类为负的样本数,实际为负,预测为负
2.查准率与查全率
查准率P(Precision)是指在所有预测为正例的样本中,真正例的比例。它衡量了模型在预测为正例的样本中的准确性。
定义为:
查全率R(Recall)是指在所有实际为正例的样本中,被正确预测为正例的比例。它衡量了模型对于正例样本的覆盖程度。
定义为:
二、ROC曲线
ROC 曲线展示了真正率(True Positive Rate,TPR)与假正率(False Positive Rate,FPR)之间的关系。TPR也称为召回率或灵敏度,表示正确预测为正例的样本占所有实际正例样本的比例。FPR则表示被错误预测为正例的负例样本占所有实际负例样本的比例。ROC 曲线横轴表示 FPR,纵轴表示 TPR。ROC 曲线越靠近左上角,说明模型性能越好。
TPR与FPR
TPR(真正例率,也称为召回率、敏感度):TPR是指在所有实际为正例的样本中,被正确预测为正例的比例。它衡量了模型对于正例样本的覆盖程度,即模型正确预测为正例的能力。
定义为:
FPR(假正例率):FPR是指在所有实际为负例的样本中,被错误预测为正例的比例。它衡量了模型在负例样本中误报正例的能力。
定义为:
三、代码实现
本次任务使用乳腺癌数据集,绘制PR曲线和ROC曲线。
1.导入数据
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
2.加载乳腺癌数据集
data = load_breast_cancer(as_frame=True)
X = data.data
y = data.target
3.创建模型并训练
model = LogisticRegression()
model.fit(X, y)
4.对样本进行预测,并计算预测得分
y_pred = model.predict(X)
scores = model.decision_function(X)
5.计算准确率、召回率和阈值
def find_threshold(y_true, scores):
fpr_list = []
tpr_list = []
precision_list = []
recall_list = []
thresholds = np.unique(scores)
for threshold in thresholds:
y_pred = np.zeros_like(y_true)
y_pred[scores >= threshold] = 1
tp = np.sum((y_true == 1) & (y_pred == 1))
fp = np.sum((y_true == 0) & (y_pred == 1))
fn = np.sum((y_true == 1) & (y_pred == 0))
tn = np.sum((y_true == 0) & (y_pred == 0))
fpr = fp / (fp + tn)
tpr = tp / (tp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
fpr_list.append(fpr)
tpr_list.append(tpr)
precision_list.append(precision)
recall_list.append(recall)
return fpr_list, tpr_list, precision_list, recall_list, thresholds
fpr_list, tpr_list, precision_list, recall_list, thresholds = find_threshold(y, scores)
6.绘制 ROC 曲线
plt.plot(fpr_list, tpr_list)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.show()
7.绘制 PR 曲线
plt.plot(recall_list, precision_list)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('PR Curve')
plt.show()
8.输出结果
四、小结
在上述代码中,首先加载乳腺癌数据集,并使用逻辑回归模型进行训练和预测。然后,自定义了 find_threshold 函数来计算不同阈值下的准确率、召回率、假正率和真正率,并将它们保存在 precision_list、recall_list、fpr_list和 tpr_list中。最后,使用 matplotlib 库绘制 ROC 曲线和 PR 曲线。
总体思路是,首先对预测结果根据分数从小到大排序,然后对于每一个分数作为阈值,分别计算真正率、假正率、准确率和召回率,并记录下来。最后,将它们画在坐标系上即可。