简介
都用库函数实现,首先导入包。
import numpy as np
from sklearn.metrics import roc_curve, f1_score, recall_score, accuracy_score, precision_score
ROC
fpr, tpr, thresholds = roc_curve(true_labels, predicted_probabilities)
# 假设你有模型的预测概率和真实标签
predicted_probabilities = ... # 模型的预测概率
true_labels = ... # 真实标签
注意这里的predicted_probabilities是预测为正例的概率。不是0,1标签,也不是两个概率里的max值!
计算完后,fpr,tpr就是坐标数组。
之后使用使用 fpr、tpr 绘制 ROC 曲线即可。
# 绘制ROC曲线
plt.plot(fpr, tpr, label='ROC curve')
plt.plot([0, 1], [0, 1], 'k--') # 绘制对角线
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc='lower right')
plt.show()
四个指标(F1,recall,acc,pre)
# 计算F1得分
predicted_labels = (predicted_probabilities > 0.5).astype(int) # 根据预测概率确定类别
f1 = f1_score(true_labels, predicted_labels)
# 计算召回率
recall = recall_score(true_labels, predicted_labels)
# 计算准确率
accuracy = accuracy_score(true_labels, predicted_labels)
# 计算正确率
precision = precision_score(true_labels, predicted_labels)
# 打印结果
print("F1 Score:", f1)
print("Recall:", recall)
print("Accuracy:", accuracy)
print("Precision:", precision)
注意,这里predicted_labels是0,1标签,不是概率!
完整代码
import numpy as np
from sklearn.metrics import roc_curve, f1_score, recall_score, accuracy_score, precision_score
# 假设你有模型的预测概率和真实标签
predicted_probabilities = ... # 模型的预测概率
true_labels = ... # 真实标签
# 计算ROC曲线
fpr, tpr, thresholds = roc_curve(true_labels, predicted_probabilities)
# 绘制ROC曲线
plt.plot(fpr, tpr, label='ROC curve')
plt.plot([0, 1], [0, 1], 'k--') # 绘制对角线
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc='lower right')
plt.show()
# 计算F1得分
predicted_labels = (predicted_probabilities > 0.5).astype(int) # 根据预测概率确定类别
f1 = f1_score(true_labels, predicted_labels)
# 计算召回率
recall = recall_score(true_labels, predicted_labels)
# 计算准确率
accuracy = accuracy_score(true_labels, predicted_labels)
# 计算正确率
precision = precision_score(true_labels, predicted_labels)
# 打印结果
print("F1 Score:", f1)
print("Recall:", recall)
print("Accuracy:", accuracy)
print("Precision:", precision)