# coding=utf-8
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, average_precision_score,precision_score,f1_score,recall_score
save_flg = True
# confusion = confusion_matrix(y_test, y_pred)
confusion = np.array([[90, 3, 3, 4],
[3, 89, 3, 5],
[3, 5, 87, 5],
[2, 3, 10, 85]])
y_true = np.array([-1]*100 + [0]*100 + [1]*100+[2]*100)
print(y_true)
y_pred = np.array([-1]*90+ [0]*3 + [1]*3 +[2]*4+
[-1]*3 + [0]*89 + [1]*3 +[2]*5+
[-1]*3 + [0]*5 + [1]*87 +[2]*5+
[-1]*2 + [0] * 3 + [1] * 10+ [2] * 85)
plt.figure(figsize=(4, 4)) # 设置图片大小
# 1.热度图,后面是指定的颜色块,cmap可设置其他的不同颜色
plt.imshow(confusion, cmap=plt.cm.Blues)
plt.colorbar() # 右边的colorbar
# 2.设置坐标轴显示列表
indices = range(len(confusion))
# classes = ['A', 'B', 'C', 'D', 'E', 'F']
classes =['0','1','2','3']
# 第一个是迭代对象,表示坐标的显示顺序,第二个参数是坐标轴显示列表
plt.xticks(indices, classes, rotation=45) # 设置横坐标方向,rotation=45为45度倾斜
plt.yticks(indices, classes)
# 3.设置全局字体
# 在本例中,坐标轴刻度和图例均用新罗马字体['TimesNewRoman']来表示
# ['SimSun']宋体;['SimHei']黑体,有很多自己都可以设置
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 4.设置坐标轴标题、字体
# plt.ylabel('True label')
# plt.xlabel('Predicted label')
# plt.title('Confusion matrix')
plt.xlabel('预测值')
plt.ylabel('真实值')
plt.title('LightGBM_混淆矩阵', fontsize=12, fontfamily="SimHei") # 可设置标题大小、字体
# 5.显示数据
normalize = False
fmt = '.2f' if normalize else 'd'
thresh = confusion.max() / 2.
for i in range(len(confusion)): # 第几行
for j in range(len(confusion[i])): # 第几列
plt.text(j, i, format(confusion[i][j], fmt),
fontsize=16, # 矩阵字体大小
horizontalalignment="center", # 水平居中。
verticalalignment="center", # 垂直居中。
color="white" if confusion[i, j] > thresh else "black")
# 6.保存图片
# if save_flg:
# plt.savefig("./picture/confusion_matrix.png")
# 7.显示
plt.show()
print('------Weighted------')
print('Weighted recall', recall_score(y_true, y_pred, average='weighted'))
print('Weighted precision', precision_score(y_true, y_pred, average='weighted'))
print('Weighted f1-score', f1_score(y_true, y_pred, average='weighted'))
print('Weighted accuracy', accuracy_score(y_true, y_pred))