2021SC@SDUSC-山东大学软件工程与实践-Senta(九)

本期对SENTA项目中metrics的sklearn_metrics进行源码分析,主要为分类统计指标的实现。在这里插入图片描述


    def evaluate(output, evaluate_types, average="macro"):
        """
        :param output:
        :param evaluate_types:
        :param average:
        :return:
        """
        if not evaluate_types:
            return
        evas = str.split(evaluate_types, ",")
        if len(evas) == 0:
            return
        ret = {}
        for eva in evas:
            if eva == "acc":
                ret[eva] = SKLearnClassify.sk_learn_acc(output)
            elif eva == "auc":
                ret[eva] = SKLearnClassify.sk_learn_auc(output)
            elif eva == "f1":
                ret[eva] = SKLearnClassify.sk_learn_f1(output, average=average)
            elif eva == "precision":
                ret[eva] = SKLearnClassify.sk_learn_precision_score(output, average=average)
            elif eva == "recall":
                ret[eva] = SKLearnClassify.sk_learn_recall_score(output, average=average)

        return ret

构建evaluate函数,依据输入类别计算 acc(准确率 accuracy)auc(AUC值为ROC曲线所覆盖的区域面积),f1值,准确度和召回率。

    def sk_learn_acc(output):
        """
        :param output:
        :return:
        """
        predict = output["classify_infer"]
        label = output["label"]
        predict_arr = None
        label_arr = None

        if isinstance(predict, list):
            predict_arr = np.array(predict).astype('int64')
        else:
            tmp_arr = []
            for pre in predict:
                tmp_arr.append(np.argmax(pre))
            predict_arr = np.array(tmp_arr)

        if isinstance(label, list):
            label_arr = np.array(label).astype('int64')
        else:
            label_arr = np.array(label.flatten())

        score = accuracy_score(label_arr, predict_arr)
        # logging.info("sklearn acc score = ", score)
        return score

计算ACC的实现。
在这里插入图片描述

    def sk_learn_auc(output):
        """
        :param output:
        :return:
        """
        predict = output["classify_infer"]
        label = output["label"]
        assert len(predict[0]) == 2, "auc metrics only support binary classification, \
                                      and the positive label must be 1, negtive label must be 0"
        predict_arr = []
        for pre in predict:
            pos_prob = pre[1]
            predict_arr.append(pos_prob)

        # y = np.array([1, 1, 2, 2])
        # pred = np.array([0.1, 0.4, 0.35, 0.8])

        fpr, tpr, thresholds = metrics.roc_curve(np.array(label.flatten()), np.array(predict_arr))
        score = metrics.auc(fpr, tpr)
        # logging.info("sklearn auc score = ", score)
        return score

计算AUC的实现。
在这里插入图片描述

    def sk_learn_f1(output, average="macro"):
        """
        :param output:
        :param average:
        :return:
        """
        predict = output["classify_infer"]
        label = output["label"]

        predict_arr = None
        label_arr = None

        if isinstance(predict, list):
            predict_arr = np.array(predict).astype('int64')
        else:
            tmp_arr = []
            for pre in predict:
                tmp_arr.append(np.argmax(pre))
            predict_arr = np.array(tmp_arr)

        if isinstance(label, list):
            label_arr = np.array(label).astype('int64')
        else:
            label_arr = np.array(label.flatten())

        score = f1_score(label_arr, predict_arr, average=average)
        # logging.info("sklearn f1 macro score = ", score)
        return score

计算f1度量值的实现。
在这里插入图片描述

    def sk_learn_precision_score(output, average="macro"):
        """
        :param output:
        :param average:
        :return:
        """
        predict = output["classify_infer"]
        label = output["label"]
        predict_arr = None
        label_arr = None

        if isinstance(predict, list):
            predict_arr = np.array(predict).astype('int64')
        else:
            tmp_arr = []
            for pre in predict:
                tmp_arr.append(np.argmax(pre))
            predict_arr = np.array(tmp_arr)

        if isinstance(label, list):
            label_arr = np.array(label).astype('int64')
        else:
            label_arr = np.array(label.flatten())

        score = precision_score(label_arr, predict_arr, average=average)
        # logging.info("sklearn precision macro score = ", score)
        return score

计算精确度的实现。
在这里插入图片描述

    def sk_learn_recall_score(output, average="macro"):
        """
        :param output:
        :param average:
        :return:
        """
        predict = output["classify_infer"]
        label = output["label"]
        predict_arr = None
        label_arr = None

        if isinstance(predict, list):
            predict_arr = np.array(predict).astype('int64')
        else:
            tmp_arr = []
            for pre in predict:
                tmp_arr.append(np.argmax(pre))
            predict_arr = np.array(tmp_arr)

        if isinstance(label, list):
            label_arr = np.array(label).astype('int64')
        else:
            label_arr = np.array(label.flatten())

        score = recall_score(label_arr, predict_arr, average=average)
        # logging.info("sklearn recall macro score = ", score)
        return score

计算召回率的实现。
在这里插入图片描述
本期源码分析到此结束,谢谢。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值