PR Computation Logic

# -*- coding: UTF-8 -*-

import argparse
import os
import cv2
import time
import json
import shutil

import numpy as np
from collections import defaultdict
from ppdet.core.workspace import load_config

COLORS = {'tp': [0, 255, 0], 'fp': [0, 0, 255], 'fn': [255, 0, 0]}

def parse_args():
    """参数解析
    """
    parser = argparse.ArgumentParser(description="you should add those parameter")          
    parser.add_argument('--config_file', type=str, default="", help="paddle train config yaml.")
    parser.add_argument('--gt_file', type=str, default="", help="ground truth file path.")
    parser.add_argument('--pred_file', type=str, default="", help="pred file path.")
    parser.add_argument('--save_badcase', action="store_true", help="save badcase images")
    parser.add_argument('--save_badcase_small', action="store_true", help="save small badcase images")
    parser.add_argument('--tasks', type=str, default="all", help="tasks need to evaluate")

    args = parser.parse_args()
    return args

def parse_test_config(config, tasks):
    """解析评测集标签的各任务属性信息:attributes,以及准备评测的任务tasks。
    
    Args:
        config: 模型训练参数文件;
        tasks:用户指定的要进行评测的任务;
    
    Return:
        attributes: dict:{attr_name: attr_info},用于描述每个任务属性的基本信息(任务索引,评测metric)
        eval_tasks: 根据用户指定的tasks和attributes解析得到的要评测的任务
    """
    attributes = {}
    for attr in config['attributes']:
        attr_name = list(attr)[0]
        attr_info = attr[attr_name]
        print("attr_name:", attr_name)
        print("attr_info", attr_info)
        if isinstance(attr_info['anno_key'], list):
            for i, anno_key in enumerate(attr_info['anno_key']):
                attributes[anno_key] = {'anno_key': anno_key}
                attributes[anno_key]['num_classes'] = attr_info['num_classes'][i]
                attributes[anno_key]['metric'] = attr_info['metric']
                if 'img' in attr_name:
                    attributes[anno_key]['anno_level'] = 'img'
                else:
                    attributes[anno_key]['anno_level'] = 'obj'
        else:
            attributes.update(attr)
    
    eval_tasks = set()
    if tasks == 'all':  # if 'all' is given, every task in attributes is evaluated
        for attr_name, attr in attributes.items():
            eval_tasks.add(attr_name)
    else:
        tasks = tasks.split(',')
        for task in tasks:
            assert task in attributes, "task: {} not in attributes!".format(task)
            eval_tasks.add(task)
            if task == 'gt_bbox':
                eval_tasks.add('gt_class')  # detection eval also needs the class labels

    return attributes, eval_tasks
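
# A minimal sketch of the expected config shape (names and values are
# illustrative, not from a real config): each 'attributes' entry describes one
# task, either with a single anno_key or a list of anno_keys sharing a metric.
#
#   config['attributes'] = [
#       {'gt_bbox': {'anno_key': 1, 'num_classes': 4, 'metric': 'DetPR', 'anno_level': 'obj'}},
#       {'img_attrs': {'anno_key': ['weather', 'scene'], 'num_classes': [3, 5], 'metric': 'ClsPR'}},
#   ]
#
# parse_test_config would expose 'gt_bbox', 'weather' and 'scene' as evaluable
# tasks ('weather'/'scene' at img level because the attr name contains 'img').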

def parse_gt_file(gt_file, gt_attrs, eval_tasks, merge_class):
    """GT 文件解析"""
    if gt_file.endswith('.json'):
        gts = parse_json_gtfile(gt_file, gt_attrs, eval_tasks, merge_class)
    else:
        gts = parse_txt_gtfile(gt_file, gt_attrs, eval_tasks, merge_class)

    return gts

def parse_json_gtfile(gt_file, gt_attrs, eval_tasks, merge_class):
    """解析COCO格式gt file"""
    from pycocotools.coco import COCO
    coco = COCO(gt_file)
    
    gts = defaultdict(dict)
    if 'gt_bbox' in eval_tasks:
        gts['boxes'] = defaultdict(list)
        gts['box_clses'] = defaultdict(dict)

    img_ids = coco.getImgIds()
    for img_id in img_ids:
        img_anno = coco.loadImgs([img_id])[0]
        img_name = img_anno['file_name'].split('/')[-1]
        im_w = float(img_anno['width'])
        im_h = float(img_anno['height'])
        
        ins_anno_ids = coco.getAnnIds(imgIds=[img_id], iscrowd=False)
        instances = coco.loadAnns(ins_anno_ids)

        obj_tasks = []
        for task in eval_tasks:
            attr_info = gt_attrs[task]
            anno_level = attr_info['anno_level']
            if anno_level == 'obj':
                obj_tasks.append(task)
                continue
            anno_key = attr_info['anno_key']
            gts[img_name][task] = int(img_anno[anno_key]) + 1

        if len(obj_tasks) != 0:
            for instance in instances:
                one_box = {}
                for task in obj_tasks:
                    attr_info = gt_attrs[task]
                    anno_key = attr_info['anno_key']
                    if task == 'gt_bbox' and 'bbox' in instance:
                        x, y, box_w, box_h = instance['bbox']
                        x1 = max(0, x)
                        y1 = max(0, y)
                        x2 = min(im_w - 1, x + max(0, box_w - 1))
                        y2 = min(im_h - 1, y + max(0, box_h - 1))
                        one_box['box'] = [x1, y1, x2, y2]
                    elif task == 'gt_class':
                        cls_id = instance['category_id']    # 0-based
                        if merge_class is not None:
                            merged = False
                            for merge_id, ids in merge_class.items():
                                if cls_id in ids:
                                    cls_id = merge_id
                                    merged = True
                                    break
                            assert merged, "cls_id: {} not in merge_class".format(cls_id)
                        one_box['cls_id'] = cls_id
                    else:
                        one_box[anno_key] = instance[anno_key]
                if len(one_box.get('box', [])) != 0:
                    # TODO: if the json carries an ignore field, use its value here
                    ignore = 0
                    cls_id = one_box['cls_id']
                    # 7-element format matching the txt parser: the ignore flag sits at index -2
                    gts['boxes'][cls_id].append([one_box['box'][0], one_box['box'][1], one_box['box'][2], one_box['box'][3], 0, ignore, img_name])
                    if img_name not in gts['box_clses'][cls_id]:
                        gts['box_clses'][cls_id][img_name] = []
                    gts['box_clses'][cls_id][img_name].append({
                        'bbox': [one_box['box'][0], one_box['box'][1], one_box['box'][2], one_box['box'][3]],
                        'matched': 0,
                        'ignore': ignore
                    })
    return gts

def parse_txt_gtfile(gt_file, gt_attrs, eval_tasks, merge_class):
    """解析txt格式的标签文件
    Args:
        gt_file:标签文件
        gt_attrs:描述标签文件格式的attributes对象,根据config解析得到
        eval_tasks:要评测的任务
        merge_class: 根据配置合并检测框类别, 例:merge_class={1: [1, 2, 3]}, 将cls_id为[1, 2, 3]的值重置为1
    """
    gts = defaultdict(dict)
    if 'gt_bbox' in eval_tasks:
        gts['boxes'] = defaultdict(list)
        gts['box_clses'] = defaultdict(dict)
    with open(gt_file, 'r') as f:
        for line in f.readlines():
            words = line.strip().split()
            img_name = words[0]
            box = []
            ignore = 0  # initialize once per line, not per task, so the flag survives the loop
            for task in eval_tasks:
                attr_info = gt_attrs[task]
                anno_key = attr_info['anno_key']
                # For txt labels, anno_key must be an integer index into the line's
                # fields; skip string keys and indices beyond the line length.
                if isinstance(anno_key, str) or anno_key >= len(words):
                    continue
                if task == 'gt_bbox':
                    x1, y1, x2, y2 = map(float, words[anno_key: anno_key + 4])
                    box = [x1, y1, x2, y2]
                elif task == 'gt_class':
                    cls_id = int(words[anno_key])
                    if merge_class is not None:
                        merged = False
                        for merge_id, ids in merge_class.items():
                            if cls_id in ids:
                                cls_id = merge_id
                                merged = True
                                break
                        assert merged, "cls_id: {} not in merge_class".format(cls_id)
                elif task == 'ignore':
                    ignore = int(words[anno_key])
                else:
                    gts[img_name][task] = int(words[anno_key]) + 1
            if len(box) != 0:
                gts['boxes'][cls_id].append([box[0], box[1], box[2], box[3], 0, ignore, img_name])

                if img_name not in gts['box_clses'][cls_id]:
                    gts['box_clses'][cls_id][img_name] = []
                gts['box_clses'][cls_id][img_name].append({
                    'bbox': [box[0], box[1], box[2], box[3]],
                    'matched': 0,
                    'ignore': ignore
                })

    return gts
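
# A hedged example of one txt label line, assuming anno_key indices
# gt_bbox=1, gt_class=5, ignore=6 (the indices are hypothetical):
#
#   img_001.jpg 10.0 20.0 110.0 220.0 2 0
#
# yields gts['boxes'][2] -> [[10.0, 20.0, 110.0, 220.0, 0, 0, 'img_001.jpg']]
# plus a matching {'bbox', 'matched', 'ignore'} entry in
# gts['box_clses'][2]['img_001.jpg'].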

def parse_pred_file(pred_file, eval_tasks, threshold=0., merge_class=None):
    """解析标准化输出结果文件
    Args:
        pred_file:预测结果文件
        eval_tasks:要评测的任务
        threshold: 检测框置信度阈值
        merge_class: 根据配置合并检测框类别, 例:merge_class={1: [1, 2, 3]}, 将cls_id为[1, 2, 3]的值重置为1
    """
    preds = defaultdict(dict)
    if 'gt_bbox' in eval_tasks:
        preds['boxes'] = defaultdict(list)
        preds['box_clses'] = defaultdict(dict)
    with open(pred_file, 'r') as f:
        for line in f.readlines():
            # each line is "<img_name> <result_json>"; split on the first
            # whitespace only so a JSON payload containing spaces still parses
            img_name, result_json = line.strip().split(None, 1)
            result_dict = json.loads(result_json)
            # parse each classification / attribute prediction
            if 'attrs' in result_dict:
                for attr in result_dict['attrs']:
                    task = attr['id'].replace('_out', '')
                    preds[img_name][task] = attr['typeId']

            # parse each detected box
            if 'boxs' in result_dict:
                for box in result_dict['boxs']:
                    if 'attrs' not in box:
                        # result file produced by standardized-framework inference
                        cls_id = box['typeId']
                        score = box['confidence']
                        x1, y1, x2, y2 = box['box']['topLeft']['x'], box['box']['topLeft']['y'], box['box']['bottomRight']['x'], box['box']['bottomRight']['y']
                    else:
                        # result file produced by SDK inference
                        for attr in box['attrs']:
                            if '_cls' in attr['id']:
                                cls_id = attr['typeId']
                            if '_conf' in attr['id']:
                                score = attr['confidence']
                        # note: this SDK format spells the key 'buttonRight'
                        x1, y1, x2, y2 = box['box']['topLeft']['x'], box['box']['topLeft']['y'], box['box']['buttonRight']['x'], box['box']['buttonRight']['y']
                    if score < threshold:
                        continue
                    if merge_class is not None:
                        merged = False
                        for merge_id, ids in merge_class.items():
                            if cls_id in ids:
                                cls_id = merge_id
                                merged = True
                                break
                        assert merged, "cls_id: {} not in merge_class".format(cls_id)    
                    preds['boxes'][cls_id].append([x1, y1, x2, y2, score, 0, img_name])
                    if img_name not in preds['box_clses'][cls_id]:
                        preds['box_clses'][cls_id][img_name] = []
                    # preds['box_clses'][cls_id][img_name].append([x1, y1, x2, y2, score, 0])
                    preds['box_clses'][cls_id][img_name].append({
                        'bbox': [x1, y1, x2, y2],
                        'score': score
                    })
    return preds
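
# A hedged example of one pred-file line (values illustrative). Each line is
# "<img_name> <result_json>"; in the standardized-framework format:
#
#   img_001.jpg {"boxs":[{"typeId":2,"confidence":0.91,"box":{"topLeft":{"x":10,"y":20},"bottomRight":{"x":110,"y":220}}}]}
#
# parses to preds['boxes'][2] -> [[10, 20, 110, 220, 0.91, 0, 'img_001.jpg']].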

def get_cls_pr(num_classes, gts, preds, task, badcase_param):
    """Compute accuracy and per-class precision/recall for a classification task,
    optionally saving misclassified images as badcases."""
    total = 0
    correct = 0
    # rows are predictions, columns are ground truth; index 0 is unused padding
    conf_mat = np.zeros((num_classes + 1, num_classes + 1), dtype=np.int32)

    for img_name in gts:
        # skip the detection entries that share the gts/preds dicts
        if img_name in ('boxes', 'box_clses'):
            continue
        if img_name not in preds or task not in preds[img_name]:
            continue
        gt = gts[img_name][task]
        pred = preds[img_name][task]
        total += 1
        conf_mat[pred, gt] += 1
        if gt == pred:
            correct += 1
        else:
            if badcase_param['save_badcase']:
                new_image_root = os.path.join(
                        badcase_param['save_root'], 
                        task,
                        'gt' + str(gt) + '_' + 'pred' + str(pred))
                if not os.path.exists(new_image_root):
                    os.makedirs(new_image_root)
                
                # copy the misclassified image into the badcase folder
                image_path = os.path.join(badcase_param['image_root'], img_name)
                new_image_path = os.path.join(new_image_root, img_name)
                shutil.copy(image_path, new_image_path)
    
    acc = correct / total if total > 0 else 0
    print('Total Acc for {}: {:.4f}, total: {}'.format(task, acc, total))

    for cls_id in range(1, num_classes + 1):
        precision = conf_mat[cls_id, cls_id] / np.sum(conf_mat[cls_id, :]) if np.sum(conf_mat[cls_id, :]) > 0 else 0
        recall = conf_mat[cls_id, cls_id] / np.sum(conf_mat[:, cls_id]) if np.sum(conf_mat[:, cls_id]) > 0 else 0
        print('Class: {}\tPrecision: {:.4f}\tRecall: {:.4f}\tTP: {}\tFP: {}\tFN: {}\tTOTAL: {}'.format(
            cls_id, precision, recall, conf_mat[cls_id, cls_id], np.sum(conf_mat[cls_id, :]) - conf_mat[cls_id, cls_id], 
            np.sum(conf_mat[:, cls_id]) - conf_mat[cls_id, cls_id], np.sum(conf_mat[:, cls_id])
        ))
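
# Convention of the confusion matrix above (rows = pred, cols = gt), with
# illustrative counts for a 2-class task where one sample has pred=1, gt=2:
#
#   conf_mat = [[0, 0, 0],
#               [0, 3, 1],   # row 1: precision(cls 1) = 3 / (3 + 1)
#               [0, 0, 2]]   # col 2: recall(cls 2)    = 2 / (1 + 2)
#
# Class ids start at 1; row/column 0 is unused padding.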

def compute_iou(bbox1, bbox2):
    xmin1, ymin1, xmax1, ymax1 = bbox1
    xmin2, ymin2, xmax2, ymax2 = bbox2
  
    # corners of the intersection rectangle
    xx1 = np.max([xmin1, xmin2])
    yy1 = np.max([ymin1, ymin2])
    xx2 = np.min([xmax1, xmax2])
    yy2 = np.min([ymax1, ymax2])
    # intersection area (zero if the boxes do not overlap)
    inter_area = (np.max([0, xx2 - xx1])) * (np.max([0, yy2 - yy1]))

    # areas of the two boxes
    area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
    area2 = (xmax2 - xmin2) * (ymax2 - ymin2)

    # IoU = intersection / union; the union subtracts inter_area exactly once,
    # since it is counted in both box areas
    union = area1 + area2 - inter_area
    if union <= 0:
        iou = 0
    else:
        iou = inter_area / union
    return iou
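
# Worked example with hypothetical boxes:
#   compute_iou([0, 0, 10, 10], [5, 5, 15, 15])
#   intersection = 5 * 5 = 25; union = 100 + 100 - 25 = 175; iou = 25 / 175 ≈ 0.143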

def get_det_pr(gts, preds, iou_threshold):
    """
    根据给定的ground truths 和 predictions,计算预测框PR和MAP指标
    
    Args:
        gts(dict): ground truths, 包含'boxes', 'box_clses'字段, 键值为cls_id,
                其中boxes的值为[[x1, y1, x2, y2, score, 0, img_name], ...];
                box_clses的值为{'img_name': [{'bbox', 'matched', 'ignore'}, ...], ...}
        preds(dict): predictions, 包含'boxes'字段, 键值为cls_id, 格式和gts基本一致;
        iou_threshold: 评测时, 匹配框iou阈值

    Returns:
        tuple: 返回一个元组(ret, badcases);
            其中 ret 为一个列表,元素为一个字典,包含以下字段:
            class (int): 当前类的id。
            precision (ndarray): 存储precision的数组。
            recall (ndarray): 存储recall的数组。
            AP (float): 当前类平均准确率。
            total positives (int): 该类中正样本的总数。
            total TP (int): 该类中true positive 的总数。
            total FP (int): 该类中false positive 的总数。
            every_tp (ndarray): 每个预测框的是TP的标志符。
            every_fp (ndarray): 每个预测框的是FP的标志符。
            ious (ndarray): 所有预测框对应的IoU。
            scores (ndarray): 所有预测框对应的score。
            pred_boxes (List[ndarray]): 所有预测框的信息。
            gt_boxes (List[ndarray]): 所有真值框的信息。

            badcases是dict: 返回一个字典,键值为图像名,值为一个列表,每个元素为每个框信息[x1, y1, x2, y2, cls_id, score, badcase_type]
    """
    ret, classes = [], []
    for cls_id in gts['boxes']:
        if cls_id in preds['boxes']:
            classes.append(cls_id)
        else:
            print('class: {} not in preds!'.format(cls_id))

    badcases = defaultdict(list)
    for cls_id in classes:
        pred_boxes = preds['boxes'][cls_id]     # [[x1, y1, x2, y2, score, 0, img_name], ...]
        gt_boxes = gts['boxes'][cls_id]         # [[x1, y1, x2, y2, 0, ignore, img_name], ...]
        gt_box_clses = gts['box_clses'][cls_id] # {'img_name': [{'bbox', 'matched', 'ignore'}, ...], ...}
        np_gt_boxes = np.array(gt_boxes)
        # np.array on mixed types yields strings, hence the '0' comparison;
        # this counts the non-ignored gt boxes (the ignore flag sits at index -2)
        num_pos = np.sum(np_gt_boxes[:, -2] == '0')

        pred_boxes = sorted(pred_boxes, key=lambda conf: conf[4], reverse=True) # sort by score, descending
        TP, FP = [], []
        ious, scores = [], []
        for pred_box in pred_boxes: # iterate over predictions
            iouMax, jMax = 0, 0
            img_name = pred_box[-1]
            score = pred_box[4]
            x1, y1, x2, y2 = map(int, pred_box[: 4])
            if img_name in gt_box_clses:
                for j in range(len(gt_box_clses[img_name])): # find the gt box with the highest IoU
                    iou = compute_iou(pred_box[:4], gt_box_clses[img_name][j]['bbox'])
                    if iou > iouMax:
                        iouMax = iou
                        jMax = j
                if iouMax > iou_threshold:
                    if gt_box_clses[img_name][jMax]['matched'] == 0: # gt box not matched yet
                        if gt_box_clses[img_name][jMax]['ignore'] == 0: # and not an ignore box
                            TP.append(1)
                            FP.append(0)
                            ious.append(iouMax)
                            scores.append(score)
                            gt_box_clses[img_name][jMax]['matched'] = 1
                            badcases[img_name].append([x1, y1, x2, y2, cls_id, score, 'tp'])
                    else:
                        if gt_box_clses[img_name][jMax]['ignore'] == 0:
                            TP.append(0)
                            FP.append(1)
                            scores.append(score)
                            badcases[img_name].append([x1, y1, x2, y2, cls_id, score, 'fp'])
                else:
                    TP.append(0)
                    FP.append(1)
                    scores.append(score)
                    badcases[img_name].append([x1, y1, x2, y2, cls_id, score, 'fp'])
            else:
                TP.append(0)
                FP.append(1)
                scores.append(score)
                badcases[img_name].append([x1, y1, x2, y2, cls_id, score, 'fp'])

        for img_name in gt_box_clses:
            for j in range(len(gt_box_clses[img_name])):
                if gt_box_clses[img_name][j]['matched'] == 0 and gt_box_clses[img_name][j]['ignore'] == 0:
                    x1, y1, x2, y2 = map(int, gt_box_clses[img_name][j]['bbox'])
                    badcases[img_name].append([x1, y1, x2, y2, cls_id, 0, 'fn'])

        TP = np.array(TP)
        FP = np.array(FP)
        sum_FP = np.cumsum(FP)
        sum_TP = np.cumsum(TP)
        recall = sum_TP / num_pos if num_pos > 0 else 0
        num_pred = np.maximum((sum_TP + sum_FP), 1e-4)
        precision = np.divide(sum_TP, num_pred)
        ap, mprec, mrec, i = calculateAveragePrecision(recall, precision)
        r = {
            'class': cls_id,
            'precision': precision,
            'recall': recall,
            'AP': ap,
            'total positives': num_pos,
            'total TP': np.sum(TP),
            'total FP': np.sum(FP),
            'every_tp': TP,
            'every_fp': FP,
            'ious': ious,
            'scores': scores,
            'pred_boxes': pred_boxes,
            'gt_boxes': gt_boxes,
        }
        ret.append(r)
    return ret, badcases
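
# Shape of a badcases entry (values illustrative): each element of
# badcases[img_name] is [x1, y1, x2, y2, cls_id, score, type], e.g. a missed
# class-2 gt box becomes [10, 20, 110, 220, 2, 0, 'fn'] (score 0, since an
# 'fn' has no matching prediction).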

def calculateAveragePrecision(rec, prec):
    # pad recall with 0/1 sentinels and precision with 0/0 sentinels
    mrec = [0] + list(rec) + [1]
    mpre = [0] + list(prec) + [0]

    # make precision monotonically decreasing (right-to-left running max)
    for i in range(len(mpre) - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])
    # indices where recall changes
    ii = []
    for i in range(len(mrec) - 1):
        if mrec[i + 1] != mrec[i]:
            ii.append(i + 1)
    # sum the area under the interpolated PR curve
    ap = 0
    for i in ii:
        ap = ap + np.sum((mrec[i] - mrec[i - 1]) * mpre[i])
    return [ap, mpre[0:len(mpre) - 1], mrec[0:len(mrec) - 1], ii]
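
# Worked example: three detections ordered by score (TP, FP, TP), 2 gt positives.
#   rec  = [0.5, 0.5, 1.0]
#   prec = [1.0, 0.5, 2/3]
# After padding and the backward max, mpre = [1, 1, 2/3, 2/3, 0]; recall
# changes at 0 -> 0.5 and 0.5 -> 1.0, so AP = 0.5 * 1.0 + 0.5 * (2/3) ≈ 0.833.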

def print_results(results, badcases, params):
    """
    用于打印评估结果, 并根据需要保存异常样本
    Args:
    results, badcases: get_det_pr返回的结果
    params (dict): 评测和可视化badcase的相关参数,包含以下字段:
        {
        'thresholds': 测试阈值列表,
        'image_root': 图片根目录路径,
        'save_root': 保存badcase的文件夹路径,
        'save_badcase': 是否保存badcase
        'save_badcase_small': 是否保存badcase小图
        }
    """
    test_score_list = params['thresholds']
    test_score_list = [test_score_list] if not isinstance(test_score_list, list) else test_score_list

    validClasses, total_npos, acc_AP = 0, 0, 0
    res_dict = {}
    for cls_index, result in enumerate(results):
        classId = result['class']
        if classId not in res_dict:
            res_dict[classId] = {}

        precision = result['precision']
        recall = result['recall']
        average_precision = result['AP']
        npos = result['total positives']
        total_tp = result['total TP']
        total_fp = result['total FP']
        every_tp = result['every_tp']
        ious = result['ious']
        scores = np.array(result['scores'])
        total_npos += npos

        if npos > 0:
            validClasses = validClasses + 1
            acc_AP = acc_AP + average_precision
            ap_str = "{0:.2f}%".format(average_precision * 100)
            print('------------------%s---------------' % classId)
            print('AP: %s (%s)' % (ap_str, classId))
            print('all results:')
            print('    gt positives: %s' % npos)
            print('    recall: %.4f' % float(total_tp/npos))
            print('    precision: %.4f' % float(total_tp/(total_tp + total_fp)))
            print('    false positives: %s' % int(total_fp))
            print('    false negatives: %s' % (npos - total_tp))

            for tmp_score_conf in test_score_list:
                index_over = int(sum(scores >= tmp_score_conf))
                if len(ious) > 0 and index_over > 0:
                    print('\ninfer results(score over ' + str(tmp_score_conf) + '):')
                    cur_recall = sum(every_tp[:index_over]) / npos
                    print('    recall: %.4f' % (cur_recall))
                    cur_pre = sum(every_tp[:index_over]) / index_over
                    print('    precision: %.4f' % (cur_pre))
                    # guard against 0/0 when no prediction above the threshold is a TP
                    cur_f1 = 2 * cur_pre * cur_recall / (cur_pre + cur_recall) if (cur_pre + cur_recall) > 0 else 0
                    print('    f1 score: %.4f' % (cur_f1))
                    print('    false positives: %s' % int(index_over - sum(every_tp[:index_over])))
                    print('    mean iou: %.4f' % (float(sum(ious[:index_over])) / index_over))
    mAP = acc_AP / validClasses if validClasses > 0 else 0
    mAP_str = "{0:.2f}%".format(mAP * 100)
    print('\n\n')
    print('mAP: %s' % mAP_str)

    if params['save_badcase']:
        print('\nbadcase save at {}\n'.format(params['save_root']))
        if params['save_badcase_small']:
            for img_name, cases in badcases.items():
                img = None
                for case in cases:
                    x1, y1, x2, y2, cls_id, score, ty = case
                    if ty in ['tp', 'fp'] and score < test_score_list[0]:
                        continue
                    if ty in ['fp', 'fn']:
                        if img is None:
                            img = cv2.imread(os.path.join(params['image_root'], img_name))
                        
                        if img is not None:
                            little_picture = img[y1:y2, x1:x2]

                            little_picture_root = os.path.join(params['save_root'], ty)
                            if not os.path.exists(little_picture_root):
                                os.makedirs(little_picture_root)

                            cv2.imwrite(os.path.join(little_picture_root, '{}_{}-{}_{}-{}_{:.2f}_{}'.format(cls_id, x1, y1, x2, y2, score, img_name)), little_picture)

        for img_name, cases in badcases.items():
            img = None
            for case in cases:
                x1, y1, x2, y2, cls_id, score, ty = case
                if ty in ['tp', 'fp'] and score < test_score_list[0]:
                    continue
                if ty in ['fp', 'fn', 'tp']:
                    if img is None:
                        img = cv2.imread(os.path.join(params['image_root'], img_name))
                    cv2.rectangle(img, (x1, y1), (x2, y2), COLORS[ty], 2)
                    cv2.putText(img, '{}-{}-{:.2f}'.format(ty, cls_id, score), (x1, y1-10), cv2.FONT_HERSHEY_COMPLEX, 2, COLORS[ty], 2)
            if img is not None:
                cv2.imwrite(os.path.join(params['save_root'], img_name), img)

def main(config_file, gt_file, pred_file, tasks, save_badcase, save_badcase_small):
    """ Main entry function 
    
    Args:
        config_file: 模型config文件
        gt_file: 标签文件
        pred_file: 标准化框架或sdk预测结果文件
        tasks: 用户指定的要评测的任务
        save_badcase: 是否要保存badcase
        save_badcase_small: 是否要保存badcase小图
    """

    # load the standardized training config; only the 'Test' (eval) section is used
    config = load_config(config_file)
    assert 'Test' in config, "config_file must contain a 'Test' section"
    config = config['Test']

    # parse per-task attributes and the set of tasks to evaluate
    gt_attrs, eval_tasks = parse_test_config(config, tasks)

    assert os.path.exists(gt_file), 'gt file path: {} does not exist!'.format(gt_file)
    assert os.path.exists(pred_file), 'pred file path: {} does not exist!'.format(pred_file)

    # score threshold(s) used to filter detection boxes by class
    thresholds = 0 if 'bbox_thresholds' not in config else config['bbox_thresholds']
    if save_badcase:
        assert not isinstance(thresholds, list), "bbox_thresholds must be a single value, not a list, when saving badcases"
    if not isinstance(thresholds, list):
        thresholds = [thresholds]
    
    merge_class = config.get('merge_class', None)
    # parse the gt_file
    gts = parse_gt_file(gt_file, gt_attrs, eval_tasks, merge_class=merge_class)
    # parse the pred_file
    preds = parse_pred_file(pred_file, eval_tasks, threshold=thresholds[0], merge_class=merge_class)

    # badcase saving parameters
    params = {}
    params['save_badcase'] = save_badcase
    params['save_badcase_small'] = save_badcase_small
    params['thresholds'] = thresholds
    params['image_root'] = config['img_root']
    params['save_root'] = os.path.join(params['image_root'], '../badcase')
    if save_badcase:
        if not os.path.exists(params['save_root']):
            os.makedirs(params['save_root'])
        else:
            new_root = os.path.join(params['image_root'], '../badcase_{}'.format(int(time.time())))
            shutil.move(params['save_root'], new_root)
            print('\nmove {} to {}\n'.format(params['save_root'], new_root))
            os.makedirs(params['save_root'])
    
    print('\nStandardization Eval Metrics\n')
    for task in eval_tasks:
        if task in ['gt_class', 'ignore']:
            continue
        metric = gt_attrs[task]['metric']
        if metric == 'ClsPR':
            get_cls_pr(
                gt_attrs[task]['num_classes'], gts, preds, task, params)
        elif metric == 'DetPR':
            results, badcases = get_det_pr(gts, preds, config['iou_threshold'])
            print_results(results, badcases, params)

if __name__ == "__main__":

    args = parse_args()
    config_file = args.config_file
    gt_file = args.gt_file
    pred_file = args.pred_file
    save_badcase = args.save_badcase
    save_badcase_small = args.save_badcase_small
    tasks = args.tasks

    main(config_file, gt_file, pred_file, tasks, save_badcase, save_badcase_small)
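
# Example invocation (script name and paths are placeholders):
#
#   python eval_pr.py --config_file configs/model.yml \
#       --gt_file data/test_gt.txt --pred_file output/pred_result.txt \
#       --tasks gt_bbox --save_badcase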