# -*- coding: UTF-8 -*-
import argparse
import os
import cv2
import time
import json
import shutil
import numpy as np
from collections import defaultdict
from ppdet.core.workspace import load_config
# Box colours used when drawing badcases with OpenCV (BGR channel order):
# true positives green, false positives red, false negatives blue.
COLORS = dict(tp=[0, 255, 0], fp=[0, 0, 255], fn=[255, 0, 0])
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``config_file``, ``gt_file``, ``pred_file``,
        ``save_badcase``, ``save_badcase_small`` and ``tasks``.
    """
    cli = argparse.ArgumentParser(description="you should add those parameter")
    # Path-like string options, all defaulting to the empty string.
    for flag, help_text in (
        ('--config_file', "paddle train config yaml."),
        ('--gt_file', "ground truth file path."),
        ('--pred_file', "pred file path."),
    ):
        cli.add_argument(flag, type=str, default="", help=help_text)
    # Boolean switches for badcase dumping.
    cli.add_argument('--save_badcase', action="store_true", help="save badcase images")
    cli.add_argument('--save_badcase_small', action="store_true", help="save small badcase images")
    cli.add_argument('--tasks', type=str, default="all", help="tasks need to evaluate")
    return cli.parse_args()
def parse_test_config(config, tasks):
    """Parse the per-task attribute description of the test set and pick the tasks to evaluate.

    Args:
        config: the ``Test`` section of the training config. ``config['attributes']`` is a list
            of single-key dicts ``{attr_name: attr_info}``. When ``attr_info['anno_key']`` is a
            list, each entry is expanded into its own task. Such tasks are image-level if
            ``'img'`` occurs in ``attr_name`` and object-level otherwise.
        tasks: comma-separated task names requested by the user, or ``'all'``.

    Returns:
        attributes: dict ``{task_name: attr_info}`` describing each task (annotation key,
            number of classes, metric, annotation level).
        eval_tasks: set of task names to evaluate.

    Raises:
        AssertionError: if a requested task is not described in ``attributes`` or no task
            was selected.
    """
    attributes = {}
    for attr in config['attributes']:
        attr_name = list(attr)[0]
        attr_info = attr[attr_name]
        print("attr_name:", attr_name)
        print("attr_info", attr_info)
        if isinstance(attr_info['anno_key'], list):
            # Grouped attribute: every annotation key becomes an independent task.
            anno_level = 'img' if 'img' in attr_name else 'obj'
            for i, anno_key in enumerate(attr_info['anno_key']):
                attributes[anno_key] = {
                    'anno_key': anno_key,
                    'num_classes': attr_info['num_classes'][i],
                    'metric': attr_info['metric'],
                    'anno_level': anno_level,
                }
        else:
            attributes.update(attr)

    if tasks == 'all':
        # Every described task takes part in the evaluation.
        return attributes, set(attributes)

    eval_tasks = set()
    # Tolerate spaces and empty entries, e.g. "gt_bbox, weather,".
    for task in (t.strip() for t in tasks.split(',')):
        if not task:
            continue
        assert task in attributes, "task: {} not in attributes!".format(task)
        eval_tasks.add(task)
        if task == 'gt_bbox':
            # Box evaluation also needs each box's class. When an ignore flag is described it is
            # needed too; otherwise ignore regions would be counted as positives, unlike 'all'.
            eval_tasks.add('gt_class')
            if 'ignore' in attributes:
                eval_tasks.add('ignore')
    assert eval_tasks, "no task selected from: {!r}".format(tasks)
    return attributes, eval_tasks
def parse_gt_file(gt_file, gt_attrs, eval_tasks, merge_class):
    """Parse a ground-truth file, choosing the reader from the file extension.

    ``*.json`` files are read as COCO annotations; any other file is read with the
    txt-format reader. All arguments are forwarded unchanged.
    """
    reader = parse_json_gtfile if gt_file.endswith('.json') else parse_txt_gtfile
    return reader(gt_file, gt_attrs, eval_tasks, merge_class)
def parse_json_gtfile(gt_file, gt_attrs, eval_tasks, merge_class):
    """Parse a COCO-format ground-truth file.

    Args:
        gt_file: path of the COCO json annotation file.
        gt_attrs: attributes dict produced by ``parse_test_config``.
        eval_tasks: task names to evaluate.
        merge_class: optional ``{merge_id: [cls_id, ...]}`` remapping of box classes.

    Returns:
        defaultdict(dict) keyed by image file basename. Each entry holds the image-level
        labels, stored as ``label + 1``. When ``'gt_bbox'`` is evaluated it also holds:
        ``'boxes'``: ``{cls_id: [[x1, y1, x2, y2, 0, ignore, img_name], ...]}``;
        ``'box_clses'``: ``{cls_id: {img_name: [{'bbox', 'matched', 'ignore'}, ...]}}``.

    Raises:
        AssertionError: if a class id is not covered by ``merge_class``.
    """
    from pycocotools.coco import COCO
    coco = COCO(gt_file)
    gts = defaultdict(dict)
    if 'gt_bbox' in eval_tasks:
        gts['boxes'] = defaultdict(list)
        gts['box_clses'] = defaultdict(dict)
    # Split the tasks once: per-image labels vs. per-instance (object-level) labels.
    img_tasks = [task for task in eval_tasks if gt_attrs[task]['anno_level'] != 'obj']
    obj_tasks = [task for task in eval_tasks if gt_attrs[task]['anno_level'] == 'obj']
    for img_id in coco.getImgIds():
        img_anno = coco.loadImgs([img_id])[0]
        img_name = img_anno['file_name'].split('/')[-1]
        im_w = float(img_anno['width'])
        im_h = float(img_anno['height'])
        for task in img_tasks:
            gts[img_name][task] = int(img_anno[gt_attrs[task]['anno_key']]) + 1
        if not obj_tasks:
            continue
        instances = coco.loadAnns(coco.getAnnIds(imgIds=[img_id], iscrowd=False))
        for instance in instances:
            one_box = {}
            ignore = 0
            for task in obj_tasks:
                anno_key = gt_attrs[task]['anno_key']
                if task == 'gt_bbox':
                    if 'bbox' not in instance:
                        continue
                    # COCO boxes are [x, y, w, h]; convert to inclusive corners clipped to the image.
                    x, y, box_w, box_h = instance['bbox']
                    one_box['box'] = [
                        max(0, x),
                        max(0, y),
                        min(im_w - 1, x + max(0, box_w - 1)),
                        min(im_h - 1, y + max(0, box_h - 1)),
                    ]
                elif task == 'gt_class':
                    cls_id = instance['category_id']  # 0-based in the annotation file
                    if merge_class is not None:
                        merged_ids = [merge_id for merge_id, ids in merge_class.items() if cls_id in ids]
                        assert merged_ids, "cls_id: {} not in merge_class".format(cls_id)
                        cls_id = merged_ids[0]
                    one_box['cls_id'] = cls_id
                elif task == 'ignore':
                    # Per-instance ignore flag, same role as the ignore column of the txt format.
                    ignore = int(instance.get(anno_key, 0))
                else:
                    one_box[anno_key] = instance[anno_key]
            if 'box' not in one_box or 'cls_id' not in one_box:
                # Without both a box and a class the instance cannot take part in box evaluation.
                continue
            cls_id = one_box['cls_id']
            x1, y1, x2, y2 = one_box['box']
            # Same 7-field layout as the txt parser; the ignore flag sits at index -2.
            gts['boxes'][cls_id].append([x1, y1, x2, y2, 0, ignore, img_name])
            gts['box_clses'][cls_id].setdefault(img_name, []).append({
                'bbox': [x1, y1, x2, y2],
                'matched': 0,
                'ignore': ignore
            })
    return gts
def parse_txt_gtfile(gt_file, gt_attrs, eval_tasks, merge_class):
    """Parse a whitespace-separated txt ground-truth file.

    Each non-blank line starts with the image name, followed by label columns. For the txt
    format every task's ``anno_key`` must be the integer column index of its value. A task
    whose ``anno_key`` is a string, or points past the end of the line, is skipped for that
    line. ``gt_bbox`` reads the four columns ``x1 y1 x2 y2`` starting at its index.

    Args:
        gt_file: label file path.
        gt_attrs: attributes dict derived from the config by ``parse_test_config``.
        eval_tasks: task names to evaluate.
        merge_class: optional class remapping, e.g. ``merge_class={1: [1, 2, 3]}`` maps class
            ids 1, 2 and 3 to 1.

    Returns:
        defaultdict(dict) with the image-level labels of each image, stored as ``label + 1``.
        When ``'gt_bbox'`` is evaluated it also holds:
        ``'boxes'``: ``{cls_id: [[x1, y1, x2, y2, 0, ignore, img_name], ...]}``;
        ``'box_clses'``: ``{cls_id: {img_name: [{'bbox', 'matched', 'ignore'}, ...]}}``.

    Raises:
        AssertionError: if a class id is not covered by ``merge_class``, or a line carries a
            box but no class id.
    """
    gts = defaultdict(dict)
    if 'gt_bbox' in eval_tasks:
        gts['boxes'] = defaultdict(list)
        gts['box_clses'] = defaultdict(dict)
    with open(gt_file, 'r') as f:
        for line in f:
            words = line.strip().split()
            if not words:
                continue  # blank line
            img_name = words[0]
            # Per-line state. It is reset here rather than per task, so that the (unordered)
            # task iteration cannot clobber a value read earlier on the same line.
            box, cls_id, ignore = [], None, 0
            for task in eval_tasks:
                anno_key = gt_attrs[task]['anno_key']
                # Only integer column indices inside the line are usable for the txt format.
                if isinstance(anno_key, str) or anno_key >= len(words):
                    continue
                if task == 'gt_bbox':
                    x1, y1, x2, y2 = map(float, words[anno_key: anno_key + 4])
                    box = [x1, y1, x2, y2]
                elif task == 'gt_class':
                    cls_id = int(words[anno_key])
                    if merge_class is not None:
                        merged_ids = [merge_id for merge_id, ids in merge_class.items() if cls_id in ids]
                        assert merged_ids, "cls_id: {} not in merge_class".format(cls_id)
                        cls_id = merged_ids[0]
                elif task == 'ignore':
                    ignore = int(words[anno_key])
                else:
                    gts[img_name][task] = int(words[anno_key]) + 1
            if not box:
                continue
            assert cls_id is not None, "image {}: box without class id".format(img_name)
            gts['boxes'][cls_id].append([box[0], box[1], box[2], box[3], 0, ignore, img_name])
            gts['box_clses'][cls_id].setdefault(img_name, []).append({
                'bbox': [box[0], box[1], box[2], box[3]],
                'matched': 0,
                'ignore': ignore
            })
    return gts
def parse_pred_file(pred_file, eval_tasks, threshold=0., merge_class=None):
    """Parse a standardized prediction result file.

    Each non-blank line is ``<img_name> <result_json>``. Only the first whitespace separates
    the two, so the json may itself contain spaces.

    Args:
        pred_file: prediction result file.
        eval_tasks: task names to evaluate. Boxes are only parsed when ``'gt_bbox'`` is among them.
        threshold: boxes with a confidence below this value are dropped.
        merge_class: optional class remapping, e.g. ``merge_class={1: [1, 2, 3]}`` maps class
            ids 1, 2 and 3 to 1.

    Returns:
        defaultdict(dict) with ``{img_name: {task: typeId}}`` for image attributes. The task
        name is the attribute id with ``'_out'`` removed. When boxes are parsed it also holds
        ``'boxes'``: ``{cls_id: [[x1, y1, x2, y2, score, 0, img_name], ...]}`` and
        ``'box_clses'``: ``{cls_id: {img_name: [{'bbox', 'score'}, ...]}}``.

    Raises:
        AssertionError: if a class id is not covered by ``merge_class``, or an SDK box lacks its
            class or confidence attribute.
    """
    parse_boxes = 'gt_bbox' in eval_tasks
    preds = defaultdict(dict)
    if parse_boxes:
        preds['boxes'] = defaultdict(list)
        preds['box_clses'] = defaultdict(dict)
    with open(pred_file, 'r') as f:
        for line in f:
            fields = line.strip().split(None, 1)
            if not fields:
                continue  # blank line
            img_name, result_json = fields
            result_dict = json.loads(result_json)
            # Image-level classification / attribute predictions.
            for attr in result_dict.get('attrs', []):
                task = attr['id'].replace('_out', '')
                preds[img_name][task] = attr['typeId']
            if not parse_boxes:
                continue
            # Detection boxes.
            for box in result_dict.get('boxs', []):
                coords = box['box']
                if 'attrs' not in box:
                    # Result file produced by the standardized framework.
                    cls_id = box['typeId']
                    score = box['confidence']
                    bottom_right = coords['bottomRight']
                else:
                    # Result file produced by SDK inference: class and confidence are per-box attrs,
                    # and the bottom-right corner is stored under the key 'buttonRight'.
                    cls_id, score = None, None
                    for attr in box['attrs']:
                        if '_cls' in attr['id']:
                            cls_id = attr['typeId']
                        if '_conf' in attr['id']:
                            score = attr['confidence']
                    assert cls_id is not None and score is not None, \
                        "image {}: box without class or confidence attr".format(img_name)
                    bottom_right = coords['buttonRight'] if 'buttonRight' in coords else coords['bottomRight']
                x1, y1 = coords['topLeft']['x'], coords['topLeft']['y']
                x2, y2 = bottom_right['x'], bottom_right['y']
                if score < threshold:
                    continue
                if merge_class is not None:
                    merged_ids = [merge_id for merge_id, ids in merge_class.items() if cls_id in ids]
                    assert merged_ids, "cls_id: {} not in merge_class".format(cls_id)
                    cls_id = merged_ids[0]
                preds['boxes'][cls_id].append([x1, y1, x2, y2, score, 0, img_name])
                preds['box_clses'][cls_id].setdefault(img_name, []).append({
                    'bbox': [x1, y1, x2, y2],
                    'score': score
                })
    return preds
def get_cls_pr(num_classes, gts, preds, task, badcase_param):
    """Evaluate an image-level classification task: print accuracy and per-class P/R.

    Args:
        num_classes: number of classes. Class ids ``1..num_classes`` are reported; row/column 0
            of the confusion matrix is not printed.
        gts / preds: dicts from the gt / pred parsers, ``{img_name: {task: label}}``. The reserved
            ``'boxes'`` / ``'box_clses'`` entries are skipped, as are images missing from
            ``preds`` or lacking ``task`` on either side.
        task: task name to evaluate.
        badcase_param: dict with ``'save_badcase'``. When that is true it must also hold
            ``'image_root'`` and ``'save_root'``, and misclassified images are copied to
            ``save_root/<task>/gt<gt>_pred<pred>/``.
    """
    total = 0
    correct = 0
    # conf_mat[pred, gt]: rows are predictions, columns are ground truth.
    conf_mat = np.zeros((num_classes + 1, num_classes + 1), dtype=np.int32)
    for img_name, gt_labels in gts.items():
        if img_name in ('boxes', 'box_clses'):
            # Detection data stored alongside the per-image labels, not an image.
            continue
        # .get() avoids inserting keys when preds is a defaultdict.
        pred_labels = preds.get(img_name)
        if pred_labels is None or task not in gt_labels or task not in pred_labels:
            continue
        gt = gt_labels[task]
        pred = pred_labels[task]
        total += 1
        conf_mat[pred, gt] += 1
        if gt == pred:
            correct += 1
        elif badcase_param['save_badcase']:
            new_image_root = os.path.join(
                badcase_param['save_root'],
                task,
                'gt' + str(gt) + '_' + 'pred' + str(pred))
            os.makedirs(new_image_root, exist_ok=True)
            image_path = os.path.join(badcase_param['image_root'], img_name)
            new_image_path = os.path.join(new_image_root, img_name)
            # Best-effort copy (no shell involved): a missing image must not abort the evaluation.
            try:
                shutil.copy(image_path, new_image_path)
            except OSError as err:
                print('failed to copy badcase {}: {}'.format(image_path, err))
    acc = correct / total if total > 0 else 0
    print('Total Acc for {}: {:.4f}, total: {}'.format(task, acc, total))
    for cls_id in range(1, num_classes + 1):
        tp = conf_mat[cls_id, cls_id]
        pred_count = np.sum(conf_mat[cls_id, :])
        gt_count = np.sum(conf_mat[:, cls_id])
        precision = tp / pred_count if pred_count > 0 else 0
        recall = tp / gt_count if gt_count > 0 else 0
        print('Class: {}\tPrecison: {:.4f}\tRecall: {:.4f}\tTP: {}\tFP: {}\tFN: {}\tTOTAL: {}'.format(
            cls_id, precision, recall, tp, pred_count - tp, gt_count - tp, gt_count
        ))
def compute_iou(bbox1, bbox2):
    """Return the intersection-over-union of two ``[xmin, ymin, xmax, ymax]`` boxes.

    Box areas are computed as ``(xmax - xmin) * (ymax - ymin)``. When the union is not positive
    (e.g. two zero-area boxes) the IoU is 0.
    """
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2
    # Overlap extent along each axis, clamped to zero when the boxes are disjoint on that axis.
    overlap_w = np.max([0, np.min([ax2, bx2]) - np.max([ax1, bx1])])
    overlap_h = np.max([0, np.min([ay2, by2]) - np.max([ay1, by1])])
    inter = overlap_w * overlap_h
    # Union counts the intersection once (area_a + area_b - inter).
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return 0 if union <= 0 else inter / union
def get_det_pr(gts, preds, iou_threshold):
    """Compute per-class precision/recall curves and AP for detection boxes.

    Predictions of each class are visited in descending score order. Each one is compared with
    the ground-truth box of the same image and class that has the highest IoU. If that IoU
    exceeds ``iou_threshold``:
      - an ignore GT box makes the prediction neither TP nor FP;
      - a GT box not yet matched makes it a TP (and marks the GT box matched);
      - an already matched GT box makes it an FP (duplicate).
    Every other prediction is an FP. Non-ignored GT boxes left unmatched are FNs.

    Args:
        gts (dict): ground truths with two fields.
            ``'boxes'``: ``{cls_id: [[x1, y1, x2, y2, ..., ignore, img_name], ...]}``, with the
            ignore flag at index -2.
            ``'box_clses'``: ``{cls_id: {img_name: [{'bbox', 'matched', 'ignore'}, ...]}}``.
        preds (dict): predictions with ``'boxes'``:
            ``{cls_id: [[x1, y1, x2, y2, score, 0, img_name], ...]}``.
        iou_threshold: a prediction matches a GT box only if their IoU is strictly greater.

    Returns:
        tuple ``(ret, badcases)``.
        ret: list with one dict per GT class, with keys:
            class, precision, recall, AP;
            total positives (number of non-ignored GT boxes), total TP, total FP;
            every_tp, every_fp, ious, scores (one entry per counted prediction, in score order);
            pred_boxes (sorted by score), gt_boxes.
            Classes with GT boxes but no predictions are included, with AP 0.
        badcases: ``{img_name: [[x1, y1, x2, y2, cls_id, score, 'tp' | 'fp' | 'fn'], ...]}``.

    The ``matched`` flags stored in ``gts['box_clses']`` are neither read nor modified, so the
    function can be called repeatedly on the same inputs.
    """
    ret = []
    badcases = defaultdict(list)
    pred_boxes_by_cls = preds.get('boxes', {})
    for cls_id, gt_boxes in gts['boxes'].items():
        if not gt_boxes:
            # Nothing to evaluate (e.g. a key created by an accidental defaultdict lookup).
            continue
        if cls_id not in pred_boxes_by_cls:
            # Still evaluated: all of its GT boxes become FNs and its AP is 0.
            print('class: {} not in preds!'.format(cls_id))
        gt_box_clses = gts['box_clses'].get(cls_id, {})
        # Positives are the GT boxes whose ignore flag (second-to-last field) is 0.
        num_pos = sum(1 for gt_box in gt_boxes if int(gt_box[-2]) == 0)
        # Highest score first, so each GT box is claimed by its best-scoring prediction.
        pred_boxes = sorted(pred_boxes_by_cls.get(cls_id, []), key=lambda box: box[4], reverse=True)
        matched = set()  # (img_name, gt index) pairs already claimed by a TP
        tp, fp, ious, scores = [], [], [], []
        for pred_box in pred_boxes:
            img_name = pred_box[-1]
            score = pred_box[4]
            x1, y1, x2, y2 = map(int, pred_box[:4])
            candidates = gt_box_clses.get(img_name, [])
            iou_max, j_max = 0, 0
            for j, gt in enumerate(candidates):
                iou = compute_iou(pred_box[:4], gt['bbox'])
                if iou > iou_max:
                    iou_max, j_max = iou, j
            is_tp = False
            if candidates and iou_max > iou_threshold:
                if candidates[j_max]['ignore'] != 0:
                    # Best match is an ignore region: the prediction is neither TP nor FP.
                    continue
                if (img_name, j_max) not in matched:
                    matched.add((img_name, j_max))
                    is_tp = True
                # Otherwise this is a duplicate detection of an already matched GT box -> FP.
            tp.append(1 if is_tp else 0)
            fp.append(0 if is_tp else 1)
            ious.append(iou_max)  # kept aligned with scores / every_tp / every_fp
            scores.append(score)
            badcases[img_name].append([x1, y1, x2, y2, cls_id, score, 'tp' if is_tp else 'fp'])
        # Non-ignored GT boxes that no prediction claimed are false negatives.
        for img_name, gt_list in gt_box_clses.items():
            for j, gt in enumerate(gt_list):
                if gt['ignore'] == 0 and (img_name, j) not in matched:
                    gx1, gy1, gx2, gy2 = map(int, gt['bbox'])
                    badcases[img_name].append([gx1, gy1, gx2, gy2, cls_id, 0, 'fn'])
        tp = np.array(tp, dtype=int)
        fp = np.array(fp, dtype=int)
        sum_tp = np.cumsum(tp)
        sum_fp = np.cumsum(fp)
        # Keep recall an array even when every GT box is ignored, so AP can still be computed.
        recall = sum_tp / num_pos if num_pos > 0 else np.zeros(len(sum_tp))
        precision = np.divide(sum_tp, np.maximum(sum_tp + sum_fp, 1e-4))
        ap = calculateAveragePrecision(recall, precision)[0]
        ret.append({
            'class': cls_id,
            'precision': precision,
            'recall': recall,
            'AP': ap,
            'total positives': num_pos,
            'total TP': np.sum(tp),
            'total FP': np.sum(fp),
            'every_tp': tp,
            'every_fp': fp,
            'ious': ious,
            'scores': scores,
            'pred_boxes': pred_boxes,
            'gt_boxes': gt_boxes,
        })
    return ret, badcases
def calculateAveragePrecision(rec, prec):
    """All-point interpolated average precision (area under the monotone P/R envelope).

    Args:
        rec: recall values, in prediction order.
        prec: precision values, one per recall value.

    Returns:
        list ``[ap, mpre, mrec, ii]``:
            ap: the average precision;
            mpre: the padded precision envelope, without its last sentinel;
            mrec: the padded recall sequence, truncated to the same length;
            ii: indices at which recall changes.
    """
    # Pad with sentinels: recall goes 0 -> ... -> 1, precision is 0 at both ends.
    mrec = [0] + list(rec) + [1]
    mpre = [0] + list(prec) + [0]
    # Sweep from the right so precision becomes monotonically non-increasing.
    for k in reversed(range(1, len(mpre))):
        mpre[k - 1] = max(mpre[k - 1], mpre[k])
    # Indices where recall steps to a new value.
    ii = [k for k in range(1, len(mrec)) if mrec[k] != mrec[k - 1]]
    ap = 0
    for k in ii:
        ap = ap + np.sum((mrec[k] - mrec[k - 1]) * mpre[k])
    return [ap, mpre[:-1], mrec[:len(mpre) - 1], ii]
def _print_class_metrics(result, test_score_list):
    """Print the overall metrics of one ``get_det_pr`` class result, then per score threshold."""
    class_id = result['class']
    npos = result['total positives']
    total_tp = result['total TP']
    total_fp = result['total FP']
    every_tp = result['every_tp']
    ious = result['ious']
    scores = np.array(result['scores'])
    num_pred = total_tp + total_fp
    print('------------------%s---------------' % class_id)
    print('AP: %s (%s)' % ("{0:.2f}%".format(result['AP'] * 100), class_id))
    print('all results:')
    print(' gt positives: %s' % npos)
    print(' recall: %.4f' % float(total_tp / npos))
    # A class without any prediction has no defined precision; report 0 instead of NaN.
    print(' precision: %.4f' % (float(total_tp / num_pred) if num_pred > 0 else 0.0))
    print(' false positives: %s' % int(total_fp))
    print(' false negeatives: %s' % (npos - total_tp))
    for tmp_score_conf in test_score_list:
        # Predictions are sorted by score, so the first index_over of them pass the threshold.
        index_over = int(sum(scores >= tmp_score_conf))
        if len(ious) > 0 and index_over > 0:
            print('\ninfer results(score over ' + str(tmp_score_conf) + '):')
            tp_over = sum(every_tp[:index_over])
            cur_recall = tp_over / npos
            print(' recall: %.4f' % (cur_recall))
            cur_pre = tp_over / index_over
            print(' precision: %.4f' % (cur_pre))
            denom = cur_pre + cur_recall
            cur_f1 = 2 * cur_pre * cur_recall / denom if denom > 0 else 0.0
            print(' f1 score: %.4f' % (cur_f1))
            print(' false positives: %s' % int(index_over - tp_over))
            print(' mean iou: %.4f' % (float(sum(ious[:index_over])) / index_over))


def _save_badcase_crops(badcases, params, min_score):
    """Save FP/FN box crops under ``save_root/<fp|fn>/``, skipping FPs scored below ``min_score``."""
    for img_name, cases in badcases.items():
        img = None
        for x1, y1, x2, y2, cls_id, score, ty in cases:
            if ty not in ('fp', 'fn'):
                continue
            if ty == 'fp' and score < min_score:
                continue
            if img is None:
                img = cv2.imread(os.path.join(params['image_root'], img_name))
                if img is None:
                    break  # unreadable image: nothing to crop
            # Clamp to the image so negative coordinates do not wrap around in numpy slicing.
            img_h, img_w = img.shape[:2]
            cx1, cy1 = max(0, x1), max(0, y1)
            cx2, cy2 = min(img_w, x2), min(img_h, y2)
            if cx2 <= cx1 or cy2 <= cy1:
                continue  # an empty crop cannot be written
            little_picture_root = os.path.join(params['save_root'], ty)
            os.makedirs(little_picture_root, exist_ok=True)
            cv2.imwrite(os.path.join(little_picture_root, '{}_{}-{}_{}-{}_{:.2f}_{}'.format(
                cls_id, x1, y1, x2, y2, score, img_name)), img[cy1:cy2, cx1:cx2])


def _save_badcase_images(badcases, params, min_score):
    """Draw TP/FP/FN boxes on each image and save it under ``save_root`` (TP/FP below ``min_score`` skipped)."""
    for img_name, cases in badcases.items():
        img = None
        for x1, y1, x2, y2, cls_id, score, ty in cases:
            if ty not in ('fp', 'fn', 'tp'):
                continue
            if ty in ('tp', 'fp') and score < min_score:
                continue
            if img is None:
                img = cv2.imread(os.path.join(params['image_root'], img_name))
                if img is None:
                    break  # unreadable image: skip it instead of failing inside cv2.rectangle
            cv2.rectangle(img, (x1, y1), (x2, y2), COLORS[ty], 2)
            cv2.putText(img, '{}-{}-{:.2f}'.format(ty, cls_id, score), (x1, y1 - 10),
                        cv2.FONT_HERSHEY_COMPLEX, 2, COLORS[ty], 2)
        if img is not None:
            cv2.imwrite(os.path.join(params['save_root'], img_name), img)


def print_results(results, badcases, params):
    """Print per-class detection metrics and the mAP, and optionally save badcase images.

    Args:
        results, badcases: the two values returned by ``get_det_pr``.
        params (dict): evaluation / visualisation options:
            'thresholds': a score threshold or a list of them. Metrics are printed for each;
                the first one also filters which TP/FP boxes are drawn as badcases;
            'image_root': directory of the source images;
            'save_root': directory the badcase images are written to;
            'save_badcase': whether to save full images with TP/FP/FN boxes drawn;
            'save_badcase_small': whether to also save FP/FN crops. Only used when
                'save_badcase' is set.
    """
    test_score_list = params['thresholds']
    if not isinstance(test_score_list, list):
        test_score_list = [test_score_list]
    valid_classes, acc_ap = 0, 0
    for result in results:
        if result['total positives'] <= 0:
            continue  # classes without positive GT boxes do not enter the mAP
        valid_classes += 1
        acc_ap += result['AP']
        _print_class_metrics(result, test_score_list)
    # No evaluable class: report 0 instead of dividing by zero.
    mAP = acc_ap / valid_classes if valid_classes > 0 else 0
    print('\n\n')
    print('mAP: %s' % "{0:.2f}%".format(mAP * 100))
    if params['save_badcase']:
        print('\nbadcase save at {}\n'.format(params['save_root']))
        if params['save_badcase_small']:
            _save_badcase_crops(badcases, params, test_score_list[0])
        _save_badcase_images(badcases, params, test_score_list[0])
def main(config_file, gt_file, pred_file, tasks, save_badcase, save_badcase_small):
    """Main entry function.

    Args:
        config_file: model config file; its ``Test`` section describes the evaluation.
        gt_file: label file.
        pred_file: prediction result file produced by the standardized framework or the SDK.
        tasks: tasks the user asked to evaluate (comma-separated, or ``'all'``).
        save_badcase: whether to save badcase images.
        save_badcase_small: whether to also save small badcase crops.
    """
    # Load the standardized training config; only the evaluation-related part is used.
    config = load_config(config_file)
    assert 'Test' in config, "config_file中需要包含Test字段"
    config = config['Test']
    # Derive task attributes and the set of tasks to evaluate.
    gt_attrs, eval_tasks = parse_test_config(config, tasks)
    assert os.path.exists(gt_file), 'gt file path: {} does not exist!'.format(gt_file)
    assert os.path.exists(pred_file), 'pred file path: {} does not exist!'.format(pred_file)
    # Score threshold(s) used to filter detection boxes.
    thresholds = 0 if 'bbox_thresholds' not in config else config['bbox_thresholds']
    if save_badcase:
        assert not isinstance(thresholds, list), "当需要保存badcase时, bbox_thresholds不能是list, 需要配置成单个数值"
    if not isinstance(thresholds, list):
        thresholds = [thresholds]
    merge_class = config.get('merge_class', None)
    gts = parse_gt_file(gt_file, gt_attrs, eval_tasks, merge_class=merge_class)
    # Filter predictions with the lowest requested threshold. Otherwise every threshold in the
    # list below thresholds[0] would be reported on boxes that were already dropped.
    preds = parse_pred_file(pred_file, eval_tasks, threshold=min(thresholds), merge_class=merge_class)
    # Badcase saving parameters.
    params = {
        'save_badcase': save_badcase,
        'save_badcase_small': save_badcase_small,
        'thresholds': thresholds,
        'image_root': config['img_root'],
    }
    params['save_root'] = os.path.join(params['image_root'], '../badcase')
    if save_badcase:
        if os.path.exists(params['save_root']):
            # Keep the previous run's badcases by moving them to a timestamped directory.
            new_root = os.path.join(params['image_root'], '../badcase_{}'.format(int(time.time())))
            shutil.move(params['save_root'], new_root)
            print('\nmove {} to {}\n'.format(params['save_root'], new_root))
        os.makedirs(params['save_root'])
    print('\nStandardization Eval Metrics\n')
    for task in eval_tasks:
        if task in ['gt_class', 'ignore']:
            continue  # helper tasks consumed by the gt_bbox evaluation
        metric = gt_attrs[task]['metric']
        if metric == 'ClsPR':
            get_cls_pr(
                gt_attrs[task]['num_classes'], gts, preds, task, params)
        elif metric == 'DetPR':
            results, badcases = get_det_pr(gts, preds, config['iou_threshold'])
            print_results(results, badcases, params)
if __name__ == "__main__":
    # Script entry: forward the parsed command-line options to main().
    cli_args = parse_args()
    main(
        cli_args.config_file,
        cli_args.gt_file,
        cli_args.pred_file,
        cli_args.tasks,
        cli_args.save_badcase,
        cli_args.save_badcase_small,
    )
# PR计算逻辑 (PR calculation logic)
# 于 2024-03-10 16:44:15 首次发布 (first published 2024-03-10 16:44:15)