import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(cm, result_path, title='Confusion Matrix', class_names=None):
    """Draw a confusion matrix as an annotated heat map, save it as PNG and show it.

    Args:
        cm: 2-D confusion matrix; ``cm[i][j]`` is drawn at row ``i``
            (ground truth) and column ``j`` (prediction).
        result_path: Path of the result file the matrix was computed from.
            The image is written next to it with the extension replaced
            by ``.png``.
        title: Title placed above the plot.
        class_names: Tick labels for both axes. When omitted, the
            module-level ``classes`` list is used.
    """
    labels = classes if class_names is None else class_names
    cm = np.asarray(cm)

    plt.figure(figsize=(4, 4), dpi=300)
    # Kept from the original implementation: sets NumPy's global print precision.
    np.set_printoptions(precision=2)

    # Write the value of every cell into the grid. Iterating over the matrix
    # itself (not over the label list) keeps this safe when their sizes differ.
    half_max = cm.max() / 2 if cm.size else 0
    for row_idx, row in enumerate(cm):
        for col_idx, value in enumerate(row):
            plt.text(
                col_idx,
                row_idx,
                "%0.2f" % (value,),
                # White text on dark cells, black text on light cells.
                color="white" if value > half_max else "black",
                fontsize=10,
                va='center',
                ha='center',
            )

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()

    tick_positions = np.arange(len(labels))
    plt.xticks(tick_positions, labels)
    plt.yticks(tick_positions, labels)
    plt.ylabel('Ground truth')
    plt.xlabel('Predict')

    # Put minor ticks on the cell borders so the grid lines separate the
    # cells, and hide the tick marks themselves.
    border_positions = tick_positions + 0.5
    axes = plt.gca()
    axes.set_xticks(border_positions, minor=True)
    axes.set_yticks(border_positions, minor=True)
    axes.xaxis.set_ticks_position('none')
    axes.yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', color="gray", linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.05)

    # Save next to the input file, swapping whatever extension it has for
    # '.png', then display the figure.
    plt.savefig(os.path.splitext(result_path)[0] + '.png', format='png')
    plt.show()
# Display names of the three classes; list position is the class index.
classes = ['M0', 'M1', 'M2']

# Synthetic demo labels: 3 classes, 50 random samples. They are not used by
# the evaluation loop below.
random_numbers = np.random.randint(3, size=50)
y_true = random_numbers.copy()  # ground-truth labels of the synthetic samples
random_numbers[:10] = np.random.randint(3, size=10)  # re-draw the first 10 samples at random
y_pred = random_numbers  # predicted labels of the synthetic samples

result_paths = ['DL_train.csv', 'DLC_train.csv', 'DL_test.csv', 'DLC_test.csv']
for result_path in result_paths:
    with open(result_path, 'r') as f:
        # Skip the header row and any blank lines. Using splitlines() keeps the
        # last record even when the file does not end with a newline.
        data_lines = [line for line in f.read().splitlines()[1:] if line.strip()]
    result_list = [line.split(',') for line in data_lines]

    # Columns as parsed here: 0 = sample id, 1 = ground truth, 2 = raw prediction.
    id_list = [int(result[0]) for result in result_list]
    # NOTE(review): ground-truth values are assumed to be integral class
    # indices (e.g. "1" or "1.0") - confirm against the CSV producer.
    y = np.array([int(float(result[1])) for result in result_list], dtype=int)
    p = np.array([float(result[2]) for result in result_list])

    # Bin the continuous predictions into class indices:
    #   p < 0.5 -> 0,  0.5 <= p < 1.5 -> 1,  p >= 1.5 -> 2.
    # Unlike the strict comparisons used before, the boundary values 0.5 and
    # 1.5 are assigned to a class instead of being left un-binned.
    p = np.digitize(p, [0.5, 1.5])

    # Pin the label set so the matrix is always len(classes) x len(classes),
    # even when a class does not occur in this file.
    cm = confusion_matrix(y, p, labels=np.arange(len(classes)))
    plot_confusion_matrix(cm, result_path, title='Confusion matrix')

    # Accuracy = correctly classified samples (the diagonal) / all samples.
    total = cm.sum()
    print(result_path, np.trace(cm) / total if total else float('nan'))