混淆矩阵、准确率、F1和召回率的具体实现及混淆矩阵的可视化

        utils专栏不会细讲概念性的内容,偏向实际使用,如有问题,欢迎留言。如果对你有帮助就点个赞哈,也不搞什么粉丝可见有的没的,有帮助点个赞就ok

1、混淆矩阵、准确率、F1和召回率的计算

混淆矩阵    对于混淆矩阵的计算,这个我们直接从sklearn.metrics导入confusion_matrix计算,只需要向其中传递两个参数,一个是y_true,一个是y_pred,就可以直接得到混淆矩阵了:

from sklearn.metrics import confusion_matrix
confMatrix = confusion_matrix(label, pre)

        这个是随便拿了个数据集,加载了预训练参数,跑了1轮的混淆矩阵: 

          准确率、F1和召回率的计算我们直接使用混淆矩阵来计算,混淆矩阵可以帮助我们很好的获得以下每类的数目:

  • TP(True Positive):正确的正例,一个实例是正类并且也被判定成正类
  • FN(False Negative):错误的反例,漏报,本为正类但判定为假类
  • FP(False Positive):错误的正例,误报,本为假类但判定为正类
  • TN(True Negative):正确的反例,一个实例是假类并且也被判定成假类

准确率:

Acc=\frac{TP+TN}{ALL}

精确率:

Pre=\frac{TP}{TP+FP}

召回率:

Recall=\frac{TP}{TP+FN}

F1:

F1_score=\frac{2 \times pre \times recall}{pre + recall}

        相关的理论部分这里不过介绍,直接上代码实现,最后返回的是总的精确率、总的召回率、总的F1_score、一个图表可视化和一个几何平均,可以根据自己的需要来调整代码,比如只需要某一类的召回率等等:

        注:使用图标可视化精确率、召回率和F1时,请先:

import prettytable
def calculate_prediction_recall(label, pre, classes=None):
    """
    计算准确率和召回率:传入预测值及对应的真实标签计算
    :param label:标签
    :param pre:对应的预测值
    :param classes:类别名(None则为数字代替)
    :return:
    """
    if classes:
        classes = list(range(classes))

    # print(classes)
    confMatrix = confusion_matrix(label, pre)
    print(confMatrix)
    total_prediction = 0
    total_recall = 0
    result_table = prettytable.PrettyTable()
    class_multi = 1
    result_table.field_names = ['Type', 'Prediction(精确率)', 'Recall(召回率)', 'F1_Score']
    for i in range(len(confMatrix)):
        label_total_sum_col = confMatrix.sum(axis=0)[i]
        label_total_sum_row = confMatrix.sum(axis=1)[i]
        if label_total_sum_col:     # 防止除0
            prediction = confMatrix[i][i] / label_total_sum_col
        else:
            prediction = 0
        if label_total_sum_row:
            recall = confMatrix[i][i] / label_total_sum_row
        else:
            recall = 0
        if (prediction + recall) != 0:
            F1_score = prediction * recall * 2 / (prediction + recall)
        else:
            F1_score = 0
        result_table.add_row([classes[i], np.round(prediction, 3), np.round(recall, 3),
                              np.round(F1_score, 3)])

        total_prediction += prediction
        total_recall += recall
        class_multi *= prediction
    total_prediction = total_prediction / len(confMatrix)
    total_recall = total_recall / len(confMatrix)
    total_F1_score = total_prediction * total_recall * 2 / (total_prediction + total_recall)
    geometric_mean = pow(class_multi, 1 / len(confMatrix))

    return total_prediction, total_recall, total_F1_score, result_table, geometric_mean, confMatrix

        图标可视化的一个结果展示(注意:图标显示的是每一类的准确率、召回率和F1,函数返回的是总的准确率、召回率和F1,可以根据自己的需要进行修改代码):

2、混淆矩阵的可视化

        没什么好说的,直接上代码:

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    混淆矩阵的可视化: 传入混淆矩阵和类别名(或数字代替)
    :param cm: 混淆矩阵
    :param classes: 类别
    :param normalize:
    :param title:
    :param cmap:
    :return:
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('runs/picture/confMatrix.jpg')
    plt.show()

        结果展示:

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值