"""
目标检测 YOLO 格式数据在给定置信度阈值和 IoU 阈值时的 PR(精确率 Precision / 召回率 Recall)指标计算脚本。
labels 文件夹下标签文件格式:cls x y w h(x、y 为框中心坐标,w、h 为框宽、高)
results 文件夹下预测结果格式:cls x y w h conf
注意:labels 文件夹和 results 文件夹下的 txt 文件名要一一对应,比如 labels/1.txt 对应 1.jpg 的标签,results/1.txt 对应 1.jpg 的预测结果。
"""
from pathlib import Path
annotation_dir = r'./labels/'  # directory holding the ground-truth label .txt files
prediction_dir = r'./results/'  # directory holding the prediction .txt files
# Only collect regular *.txt files. A bare '*' glob also matched stray entries
# (sub-directories, .DS_Store, ...) that the parsers below cannot read.
# Sorting makes the processing order deterministic across platforms.
ann_files = sorted(p for p in Path(annotation_dir).glob('*.txt') if p.is_file())
pre_files = sorted(p for p in Path(prediction_dir).glob('*.txt') if p.is_file())
annotation = {}  # file name -> list of [x_min, y_min, x_max, y_max, class_id]
prediction = {}  # file name -> list of [x_min, y_min, x_max, y_max, class_id]
CONF_THRE = 0.5  # predictions with confidence below this value are discarded
IOU_THRE = 0.5  # minimum IoU for a prediction to match a label of the same class
# Parse every label file into corner-format boxes keyed by file name.
for ann_file in ann_files:
    boxes = []
    with open(ann_file, 'r') as fin:
        for line in fin:
            line_list = line.split()
            if not line_list:
                # Blank / whitespace-only line (e.g. a trailing newline):
                # indexing the empty split would raise IndexError.
                continue
            cls_id = int(line_list[0])
            x_c, y_c, w, h = (float(v) for v in line_list[1:5])
            # Labels store box centre + size; convert to corner coordinates.
            # obj layout: x_min, y_min, x_max, y_max, annotation_class_id
            obj = [x_c - w / 2, y_c - h / 2, x_c + w / 2, y_c + h / 2, cls_id]
            boxes.append(obj)
    annotation[Path(ann_file).name] = boxes
# Parse every prediction file into corner-format boxes keyed by file name,
# keeping only boxes whose confidence reaches CONF_THRE.
for pre_file in pre_files:
    boxes = []
    with open(pre_file, 'r') as fin:
        for line in fin:
            line_list = line.split()
            if not line_list:
                # Blank / whitespace-only line: skip instead of IndexError.
                continue
            cls_id = int(line_list[0])
            x_c, y_c, w, h = (float(v) for v in line_list[1:5])
            conf = float(line_list[5])
            # Predictions store box centre + size; convert to corners.
            # obj layout: x_min, y_min, x_max, y_max, prediction_class_id
            obj = [x_c - w / 2, y_c - h / 2, x_c + w / 2, y_c + h / 2, cls_id]
            if conf >= CONF_THRE:
                boxes.append(obj)
    prediction[Path(pre_file).name] = boxes
def calculate_iou(rect1, rect2):
    """Return the intersection-over-union of two corner-format boxes.

    Each rect is indexed as [x_min, y_min, x_max, y_max, ...]; any trailing
    items (such as a class id) are ignored. Returns 0 when the boxes share no
    positive-area region, including when they merely touch along an edge.
    """
    inter_w = min(rect1[2], rect2[2]) - max(rect1[0], rect2[0])
    inter_h = min(rect1[3], rect2[3]) - max(rect1[1], rect2[1])
    if inter_w <= 0 or inter_h <= 0:
        return 0
    inter = inter_w * inter_h
    area1 = (rect1[2] - rect1[0]) * (rect1[3] - rect1[1])
    area2 = (rect2[2] - rect2[0]) * (rect2[3] - rect2[1])
    # Union = both areas minus the doubly-counted intersection.
    return float(inter) / (area1 + area2 - inter)
# Match predictions to labels and report precision / recall at the fixed
# CONF_THRE / IOU_THRE operating point.
correct_obj_num = 0      # true positives
prediction_obj_num = 0   # all kept predictions (TP + FP)
annotation_obj_num = 0   # all ground-truth boxes (TP + FN)
for value in annotation.values():
    annotation_obj_num += len(value)
for key, prediction_value in prediction.items():
    # Predictions on an image without a label file are all false positives,
    # so they must still count towards the precision denominator.
    prediction_obj_num += len(prediction_value)
    value = annotation.get(key, [])
    matched = [False] * len(value)  # each label may be claimed only once
    for pred in prediction_value:
        best_j = -1
        best_iou = -1.0
        for j, gt in enumerate(value):
            if matched[j] or gt[4] != pred[4]:
                continue
            iou = calculate_iou(pred, gt)
            if iou >= IOU_THRE and iou > best_iou:
                best_iou, best_j = iou, j
        if best_j >= 0:
            # Claim the best-overlapping free label; later duplicate
            # detections of the same object become false positives.
            matched[best_j] = True
            correct_obj_num += 1
# Guard against empty inputs: report 0.0 instead of raising ZeroDivisionError.
P = float(correct_obj_num) / prediction_obj_num if prediction_obj_num else 0.0
R = float(correct_obj_num) / annotation_obj_num if annotation_obj_num else 0.0
print('Precision ratio: ', P)
print('Recall ratio: ', R)