【Python_Code】
"""
官方文档:https://scikit-learn.org/stable/modules/model_evaluation.html
时间:2021年08月22日15:26:40
作者:陈嘿萌
"""
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve
if __name__ == '__main__':
# 1.统计报告:precision recall f1-score accuracy support
# ================================================================= #
# y_true : 真实标签 y_fit = 预测结果 y_prob:预测结果的概率
# ================================================================= #
# target_name : 类标签
print("#=============================1.support===============================#")
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_fit = [0, 1, 1, 1, 2, 3, 3, 0]
y_prob = [0.5, 0.6, 0.7, 0.8, 0.9, 0.4, 0.3, 0.1]
target_names = ['class1', 'class2', 'class3', 'class4']
support = classification_report(y_true, y_fit, target_names=target_names)
print(support)
# support是str类型,不方便取值,如果想要取值的话可以 转换成字典类型:output_dict = True
support_dict = classification_report(y_true, y_fit, target_names=target_names, output_dict=True)
# 转换成字典取值
for k, v in support_dict.items():
print(k, v)
print("accuracy:", support_dict['accuracy'])
# 单独求准确率指标
accuracy = accuracy_score(y_true, y_fit)
print("accuracy:", accuracy)
# 2.计算混淆矩阵: confusion_matrix
# ================================================================= #
# 参考链接:https://blog.csdn.net/qq_36264495/article/details/88074246
# 样式设计:https://blog.csdn.net/ztf312/article/details/102474190
# ================================================================= #
print("#==========================2.计算混淆矩阵=======================================#")
matrix = confusion_matrix(y_true, y_fit)
plt.matshow(matrix, cmap='YlOrRd')
plt.title("Confusion_Matrix")
plt.show()
print("混淆矩阵:\n", matrix)
# 3.计算ROC曲线和AUC值
# ================================================================= #
# roc_curve(真实标签, 对应类别的概率, pos_label=指定正例样本类别)
# metrics.auc(假正类率, 真正类率)
# 每个类别可以绘制一条ROC曲线求一个AUC值, 把当前类别指定为正例即可
# =============================参考博客==================================== #
# https://blog.csdn.net/sun91019718/article/details/101314545
# https://blog.csdn.net/yuxiaosmd/article/details/83046162
# https://blog.csdn.net/akadiao/article/details/78788864
# roc:https://blog.csdn.net/taotiezhengfeng/article/details/80456110
# ================================================================= #
print("#==========================3.ROC曲线_AUC值=======================================#")
fpr, tpr, threshold = roc_curve(y_true, y_prob, pos_label=0)
auc = metrics.auc(fpr, tpr)
print("auc:", auc)
# 绘制roc曲线
plt.plot(fpr, tpr)
plt.title("class=0")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.show()
【我的笔记】