为 sklearn 的 classification_report 加入 top-k 指标计算

参考:https://blog.csdn.net/dipizhong7224/article/details/104579159
sklearn 官方实现(classification_report 源码):https://github.com/scikit-learn/scikit-learn/blob/7f9bad99d6e0a3e8ddf92a7e5561245224dab102/sklearn/metrics/_classification.py#L1551

def classification_report_topk(y_true, y_pred, topk=1, labelnames=None, digits=2, output_dict=False,):
    """Build a scikit-learn style classification report for top-k predictions.

    A sample is a hit for its true label when that label appears among the
    first ``topk`` entries of its prediction row.  Per label:
    precision = TP / (TP + FP), recall = TP / (TP + FN), f1 = harmonic mean,
    support = number of samples whose true label is that label.

    Args:
        y_true: True labels, one per sample, e.g. ``[1, 1, 2, 3]``
            (list or 1-D numpy array).
        y_pred: Ranked candidate rows, one per sample, e.g.
            ``[[1, 3], [3, 2], [2, 3], [1, 2]]`` (list of lists or 2-D array).
            Only the first ``topk`` entries of each row are used.
        topk: Number of leading candidates per row treated as predictions.
        labelnames: Labels to report on, e.g. ``[1, 2, 3]``.  When ``None``
            they are inferred with ``sklearn.utils.multiclass.unique_labels``
            from ``y_true`` and the top-k predictions.
        digits: Decimal places used for floats in the text report.
        output_dict: If true, return a dict instead of a formatted string.

    Returns:
        A formatted text report.  When ``output_dict`` is true, a dict
        instead.  Its keys are each label (as ``str``), ``'accuracy'``,
        ``'macro avg'`` and ``'weighted avg'``.  Each value is a dict with
        keys ``'precision'``, ``'recall'``, ``'f1-score'`` and ``'support'``,
        all stored as ``float``.

    Raises:
        AssertionError: if ``topk`` exceeds the length of the first
            prediction row.
        ValueError: if ``y_true`` and ``y_pred`` have different lengths.
    """
    from collections import Counter

    import numpy as np

    if topk > len(y_pred[0]):
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        raise AssertionError('topk out of bounds!')

    # Normalise to plain lists so numpy inputs work too (ndarray has no
    # ``.count`` and cannot be compared to None with ``==``).
    y_true = list(y_true)
    top_preds = [list(row[:topk]) for row in y_pred]
    if len(y_true) != len(top_preds):
        raise ValueError(
            'y_true and y_pred have different lengths: %d != %d'
            % (len(y_true), len(top_preds))
        )

    if labelnames is None:
        from sklearn.utils.multiclass import unique_labels
        # Infer labels from the top-k predictions only: candidates ranked
        # below topk are not predictions and would otherwise add all-zero
        # rows that dilute the macro average.
        flat_preds = [label for row in top_preds for label in row]
        labelnames = unique_labels(y_true, flat_preds)

    # One pass over the data instead of one scan per label.
    support = Counter(y_true)  # TP + FN per label
    # TP + FP: a label counts once per row that contains it.
    predicted = Counter(label for row in top_preds for label in set(row))
    # TP: the true label is among the top-k candidates.
    hits = Counter(t for t, row in zip(y_true, top_preds) if t in row)

    rows = []
    tp_sum = 0  # hits restricted to the reported labels, used for accuracy
    for label in labelnames:
        tp, tp_fp, tp_fn = hits[label], predicted[label], support[label]
        tp_sum += tp
        precision = tp / tp_fp if tp_fp else 0.0
        recall = tp / tp_fn if tp_fn else 0.0
        # Harmonic mean; zero whenever either component is zero.
        f1 = 2 / (1 / precision + 1 / recall) if precision and recall else 0.0
        rows.append([str(label), precision, recall, f1, tp_fn])

    label_rows = list(rows)
    accuracy_topk = tp_sum / len(y_true)
    rows.append(['accuracy', accuracy_topk, accuracy_topk, accuracy_topk, len(y_true)])

    supports = [row[4] for row in label_rows]
    for avg_name, weights in (('macro', None), ('weighted', supports)):
        if weights is not None and sum(weights) == 0:
            # np.average raises ZeroDivisionError on all-zero weights.
            averaged = [0.0, 0.0, 0.0]
        else:
            averaged = [
                np.average([row[col] for row in label_rows], weights=weights)
                for col in (1, 2, 3)
            ]
        rows.append([avg_name + ' avg', *averaged, len(y_true)])

    headers = ["precision", "recall", "f1-score", "support"]
    if output_dict:
        return {row[0]: dict(zip(headers, [float(v) for v in row[1:]])) for row in rows}

    longest_last_line_heading = "weighted avg"
    name_width = max(len(row[0]) for row in rows)
    width = max(name_width, len(longest_last_line_heading), digits)
    head_fmt = "{:>{width}s} " + " {:>9}" * len(headers)
    row_fmt = "{:>{width}s} " + " {:>9.{digits}f}" * 3 + " {:>9}\n"
    report = head_fmt.format("", *headers, width=width) + "\n\n"
    report += "".join(row_fmt.format(*row, width=width, digits=digits) for row in rows)
    return report + "\n"
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值