多分类模型评价指标、recall、precision、f1、accuracy的计算,附代码

# coding=utf-8

import matplotlib.pyplot as plt

import numpy as np

from sklearn.metrics import confusion_matrix

from sklearn.metrics import accuracy_score, average_precision_score,precision_score,f1_score,recall_score

save_flg = True

# confusion = confusion_matrix(y_test, y_pred)

confusion = np.array([[90, 3, 3, 4],

                      [3, 89, 3, 5],

                      [3, 5, 87, 5],

                      [2, 3, 10, 85]])


 

y_true = np.array([-1]*100 + [0]*100 + [1]*100+[2]*100)

print(y_true)

y_pred = np.array([-1]*90+ [0]*3 + [1]*3 +[2]*4+

                  [-1]*3 + [0]*89 + [1]*3 +[2]*5+

                  [-1]*3 + [0]*5 + [1]*87 +[2]*5+

                  [-1]*2 + [0] * 3 + [1] * 10+ [2] * 85)

plt.figure(figsize=(4, 4))  # 设置图片大小

# 1.热度图,后面是指定的颜色块,cmap可设置其他的不同颜色

plt.imshow(confusion, cmap=plt.cm.Blues)

plt.colorbar()  # 右边的colorbar

# 2.设置坐标轴显示列表

indices = range(len(confusion))

# classes = ['A', 'B', 'C', 'D', 'E', 'F']

classes =['0','1','2','3']

# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表

plt.xticks(indices, classes, rotation=45)  # 设置横坐标方向,rotation=45为45度倾斜

plt.yticks(indices, classes)

# 3.设置全局字体

# 在本例中,坐标轴刻度和图例均用新罗马字体['TimesNewRoman']来表示

# ['SimSun']宋体;['SimHei']黑体,有很多自己都可以设置

plt.rcParams['font.sans-serif'] = ['SimHei']

plt.rcParams['axes.unicode_minus'] = False

# 4.设置坐标轴标题、字体

# plt.ylabel('True label')

# plt.xlabel('Predicted label')

# plt.title('Confusion matrix')

plt.xlabel('预测值')

plt.ylabel('真实值')

plt.title('LightGBM_混淆矩阵', fontsize=12, fontfamily="SimHei")  # 可设置标题大小、字体

# 5.显示数据

normalize = False

fmt = '.2f' if normalize else 'd'

thresh = confusion.max() / 2.

for i in range(len(confusion)):  # 第几行

    for j in range(len(confusion[i])):  # 第几列

        plt.text(j, i, format(confusion[i][j], fmt),

                 fontsize=16,  # 矩阵字体大小

                 horizontalalignment="center",  # 水平居中。

                 verticalalignment="center",  # 垂直居中。

                 color="white" if confusion[i, j] > thresh else "black")

# 6.保存图片

# if save_flg:

    # plt.savefig("./picture/confusion_matrix.png")

# 7.显示

plt.show()



 

print('------Weighted------')

print('Weighted recall', recall_score(y_true, y_pred, average='weighted'))

print('Weighted precision', precision_score(y_true, y_pred, average='weighted'))

print('Weighted f1-score', f1_score(y_true, y_pred, average='weighted'))

print('Weighted accuracy', accuracy_score(y_true, y_pred))

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值