多分类任务混淆矩阵(python代码实现)

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, result_path, title='Confusion Matrix'):

    plt.figure(figsize=(4, 4), dpi=300)
    np.set_printoptions(precision=2)

    # 在混淆矩阵中每格的概率值
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm[y_val][x_val]
        plt.text(x_val, y_val, "%0.2f" % (c,), color="white"  if c > cm.max()/2 else "black", fontsize=10, va='center', ha='center')
    
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, classes)
    plt.yticks(xlocations, classes)
    plt.ylabel('Ground trurh')
    plt.xlabel('Predict')
    
    # offset the tick
    tick_marks = np.array(range(len(classes))) + 0.5
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', color="gray", linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.05)

  
    # show confusion matrix
    plt.savefig(result_path[:-4]+'.png', format='png')
    plt.show()

classes = ['M0', 'M1', 'M2']

random_numbers = np.random.randint(3, size=50)  # 6个类别,随机生成50个样本
y_true = random_numbers.copy()  # 样本实际标签
random_numbers[:10] = np.random.randint(3, size=10)  # 将前10个样本的值进行随机更改
y_pred = random_numbers  # 样本预测标签


result_paths=['DL_train.csv', 'DLC_train.csv','DL_test.csv', 'DLC_test.csv']

for result_path in result_paths:
    with open(result_path, 'r') as f:
        result_list = f.read()
    
    result_list = result_list.split('\n')[1:-1]
    result_list = [result.split(',') for result in result_list]
        
    
    id_list = [int(result[0]) for result in result_list]
    y = np.array([float(result[1]) for result in result_list])
    p = np.array([float(result[2]) for result in result_list])
    
    p[p<0.5]=0
    p[(p>0.5)*(p<1.5)]=1
    p[p>1.5]=2


    cm = confusion_matrix(y, p)
    plot_confusion_matrix(cm, result_path, title='Confusion matrix',)
    
    print(result_path, (cm[0,0]+cm[1,1]+cm[2,2])/cm.sum())

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值