from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import torch
def plot_confusion_matrix(cm, savename, classes, title='Confusion Matrix', normalize=True):
    """Draw a confusion matrix as an annotated heat map and display it.

    Args:
        cm: Confusion matrix (array-like), expected to be at least
            ``len(classes)`` x ``len(classes)``.  Rows are drawn on the
            "Actual label" axis and columns on the "Predict label" axis.
        savename: Intended output file name.  Currently unused: the figure
            is only shown on screen, nothing is written to disk.
        classes: Sequence of class names.  Its length sets how many cells
            are annotated per axis and supplies the tick labels.
        title: Title drawn above the plot.
        normalize: If true, every row is divided by its own sum, so a cell
            holds the fraction of that row's total.  Otherwise the values
            of ``cm`` are plotted unchanged.

    Returns:
        None.  The figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(12, 8), dpi=100)
    # Kept from the original; note that this changes NumPy's print
    # precision for the whole process, not just for this function.
    np.set_printoptions(precision=2)

    if normalize:
        # np.array(..., dtype=float) copies, so the caller's matrix is
        # never modified.
        cm = np.array(cm, dtype=float)
        # keepdims=True keeps the sums as a column vector, so each row is
        # divided by its own total.  A flat vector of sums would broadcast
        # across columns and divide column j by the total of row j.
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without any samples stay all-zero instead of becoming NaN.
        row_totals[row_totals == 0] = 1.0
        cm = cm / row_totals
        cell_format = "%0.2f"
    else:
        cm = np.asarray(cm)
        # Integer counts are printed as integers ("5", not "5.00");
        # anything else keeps the two-decimal format.
        cell_format = "%d" if cm.dtype.kind in "iu" else "%0.2f"

    n_classes = len(classes)

    # Write the value of every non-empty cell at the centre of that cell.
    for row in range(n_classes):
        for col in range(n_classes):
            value = cm[row][col]
            if value > 0.001:
                plt.text(col, row, cell_format % (value,), color='black',
                         fontsize=10, va='center', ha='center')

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()

    # Major ticks sit on the cell centres and carry the class names.
    centres = np.arange(n_classes)
    plt.xticks(centres, classes, rotation=90)
    plt.yticks(centres, classes)
    plt.ylabel('Actual label')
    plt.xlabel('Predict label')

    # Minor ticks are shifted by half a cell so the grid lines fall on the
    # borders between cells; the tick marks themselves are hidden.
    borders = centres + 0.5
    axes = plt.gca()
    axes.set_xticks(borders, minor=True)
    axes.set_yticks(borders, minor=True)
    axes.xaxis.set_ticks_position('none')
    axes.yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')

    # Extra bottom margin for the rotated x tick labels.
    plt.gcf().subplots_adjust(bottom=0.15)

    # NOTE: ``savename`` is not used yet; the figure is only displayed.
    plt.show()
if __name__ == '__main__':
    from pathlib import Path

    # Built from path components so the separator is right on every OS;
    # a literal with backslashes is only a valid relative path on Windows.
    path = Path("..") / "data" / "confusion_matrix.npy"

    # SECURITY: allow_pickle=True lets np.load unpickle arbitrary objects.
    # Only load files that come from a trusted source.
    contents = np.load(path, allow_pickle=True)
    # Presumably the file holds a dict saved with np.save, which comes back
    # wrapped in an object array; .item() unwraps it.
    contents = contents.item()
    y_true = contents["y_true"]
    y_pred = contents["y_pre"]

    # Display names of the classes (ten in this example).
    classes = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']

    # Build the confusion matrix and plot it.
    cm = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(cm, 'confusion_matrix.png', classes, title='confusion matrix')
# Source article title: "How to draw a confusion matrix in Python"
# Page footer: latest recommended article published 2024-02-28 18:36:44