统计模型的真实检测结果(包括 TP、FN、FP,分别对应正检、漏检、错检),并将结果映射到图片上。适用于 YOLO 格式的真值标签以及检测时生成的 YOLO 格式预测标签。
需要先将检测结果保存为 YOLO 格式的 txt 标签(例如运行 YOLOv5 的 detect.py 时加上 --save-txt 参数),如下所示:
import os
import cv2
import tqdm
import shutil
import numpy as np
def xywh2xyxy(box):
    """Convert boxes from centre format to corner format, in place.

    Args:
        box: 2-D array whose first four columns are ``cx, cy, w, h``. Those
            columns are overwritten with ``x1, y1, x2, y2``; any further
            columns are left untouched.

    Returns:
        The very same array object that was passed in.
    """
    # Move the centre to the top-left corner: (cx, cy) - (w, h) / 2.
    box[:, 0:2] = box[:, 0:2] - box[:, 2:4] / 2
    # Bottom-right = top-left (already stored in columns 0-1) + (w, h).
    box[:, 2:4] = box[:, 0:2] + box[:, 2:4]
    return box
def iou(box1, box2):
    """Return the pairwise IoU matrix between two sets of corner-format boxes.

    Widths and heights use the inclusive-pixel convention (``x2 - x1 + 1``),
    so a box whose corners coincide still has an area of 1.

    Args:
        box1: (N, 4) array of ``x1, y1, x2, y2``.
        box2: (M, 4) array of ``x1, y1, x2, y2``.

    Returns:
        (N, M) array where entry ``[i, j]`` is IoU(box1[i], box2[j]).
    """
    a_x1, a_y1, a_x2, a_y2 = np.split(box1, 4, axis=1)          # each (N, 1)
    b_x1, b_y1, b_x2, b_y2 = (col.T for col in np.split(box2, 4, axis=1))  # each (1, M)

    # Broadcasting (N, 1) against (1, M) yields every pair at once.
    inter_w = np.maximum(0, np.minimum(a_x2, b_x2) - np.maximum(a_x1, b_x1) + 1)
    inter_h = np.maximum(0, np.minimum(a_y2, b_y2) - np.maximum(a_y1, b_y1) + 1)
    intersection = inter_w * inter_h

    area_a = (a_x2 - a_x1 + 1) * (a_y2 - a_y1 + 1)
    area_b = (b_x2 - b_x1 + 1) * (b_y2 - b_y1 + 1)
    return intersection / (area_a + area_b - intersection)
def draw_box(img, box, color, *, thickness=1):
    """Draw an axis-aligned rectangle on ``img`` and return the image.

    Args:
        img: Image array accepted by ``cv2.rectangle``; it is drawn on in place.
        box: Sequence whose first four items are ``x1, y1, x2, y2`` in pixels.
            Values are truncated with ``int`` because OpenCV needs integer points.
        color: Line color in the image's channel order (BGR for images loaded
            with ``cv2.imread``).
        thickness: Line thickness in pixels. Defaults to 1, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The same ``img`` object, so ``img = draw_box(img, ...)`` keeps working.
    """
    x1, y1, x2, y2 = (int(v) for v in box[:4])
    cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness=thickness)
    return img
def draw_text(img, text, position, color=(255, 255, 255), *, font_scale=1, thickness=2):
    """Render ``text`` onto ``img`` in place and return the image.

    Args:
        img: Image array accepted by ``cv2.putText``; it is drawn on in place.
        text: String to render.
        position: ``(x, y)`` integer pixel coordinates of the text's
            bottom-left corner.
        color: Text color in the image's channel order (BGR for images loaded
            with ``cv2.imread``); white by default.
        font_scale: Font size multiplier. Defaults to 1, the value that was
            previously hard-coded.
        thickness: Stroke thickness in pixels. Defaults to 2, the value that
            was previously hard-coded.

    Returns:
        The same ``img`` object.
    """
    cv2.putText(img, text, position, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color,
                thickness, cv2.LINE_AA)
    return img
def _load_yolo_boxes(file_path, w, h):
    """Read one YOLO-format txt file and return its rows with boxes in absolute xyxy pixels.

    Every non-blank line must be ``cls cx cy bw bh [extra ...]`` with the box
    coordinates normalised by the image size. Columns 1-4 are converted to
    ``x1 y1 x2 y2`` and scaled by width ``w`` / height ``h``. Column 0 (the
    class id) and any trailing columns are kept unchanged.

    Returns:
        float32 array of shape (N, C) with C >= 5. A file without boxes gives
        shape (0, 5).

    Raises:
        OSError: the file cannot be opened (e.g. it does not exist).
        ValueError: a value is non-numeric, rows differ in length, or a row has
            fewer than 5 columns.
    """
    with open(file_path, encoding='utf-8') as f:
        rows = [line.split() for line in f if line.strip()]
    if not rows:
        return np.zeros((0, 5), dtype=np.float32)
    boxes = np.array(rows, dtype=np.float32)
    if boxes.ndim != 2 or boxes.shape[1] < 5:
        raise ValueError(f'{file_path}: expected rows of at least 5 numbers')
    boxes[:, 1:5] = xywh2xyxy(boxes[:, 1:5])
    boxes[:, [1, 3]] *= w
    boxes[:, [2, 4]] *= h
    return boxes


def _match_boxes(label, pred, iou_threshold):
    """Greedily pair ground-truth boxes with same-class predictions.

    Ground-truth rows are visited in file order. For each one, the predictions
    not yet consumed are scanned from highest to lowest IoU. The first one
    with IoU >= ``iou_threshold`` and the same class id becomes its match and
    is removed from further consideration.

    Args:
        label: (N, >=5) array of ground-truth rows ``cls x1 y1 x2 y2``.
        pred: (M, >=5) array of prediction rows ``cls x1 y1 x2 y2 ...``.
        iou_threshold: Minimum IoU for a prediction to count as a hit.

    Returns:
        ``(true_pos, false_neg, false_pos)``, each a list of rows:
        ``true_pos`` holds the matched prediction rows (correct detections),
        ``false_neg`` holds the ground-truth rows nobody matched (missed
        detections), and ``false_pos`` holds the leftover prediction rows
        (wrong detections).
    """
    remaining = list(pred)
    true_pos, false_neg = [], []
    for gt in label:
        hit = None
        # Even when predictions have run out, keep looping so that every
        # remaining ground truth is still counted as missed.
        if remaining:
            ious = iou(gt[None, 1:5], np.array(remaining)[:, 1:5])[0]
            for j in ious.argsort()[::-1]:
                if ious[j] < iou_threshold:
                    break  # sorted descending: nothing further can reach the threshold
                if gt[0] == remaining[j][0]:
                    hit = j
                    break
        if hit is None:
            false_neg.append(gt)
        else:
            true_pos.append(remaining.pop(hit))
    return true_pos, false_neg, remaining


def _format_counts(prefix, stat):
    """Format one result line from a dict holding 'TP', 'FN' and 'FP' counts."""
    return f"{prefix}: right(TP):{stat['TP']} missing(FN):{stat['FN']} error(FP):{stat['FP']}"


if __name__ == '__main__':
    # Image file extension of the inferred images
    postfix = 'jpg'
    # Directory of images that have already been run through the detector
    img_path = r'Y:\yolov5-v7.0\test_exp\detect\result'
    # Ground-truth label directory
    label_path = r'Y:\yolov5-v7.0\my_dataset\neu\labels/test'
    # Directory of predicted labels saved at inference time
    predict_path = r'Y:\yolov5-v7.0\test_exp\detect\result\labels'
    # Output directory. WARNING: it is deleted and recreated on every run, so
    # point it at a new, dedicated directory.
    save_path = r'Y:\yolov5-v7.0/test_exp\TP_FP_FN'
    # Class names, indexed by the class id found in the label files
    classes = ['Cr', 'In', 'pa', 'PS', 'RS', 'Sc']
    # Alternative class list for another dataset:
    # classes = ['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
    # Colors are BGR (OpenCV): green = TP (correct), red = FN (missed), blue = FP (wrong)
    detect_color, missing_color, error_color = (0, 255, 0), (0, 0, 255), (255, 0, 0)
    iou_threshold = 0.45
    no_boxes = np.zeros((0, 5), dtype=np.float32)

    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    stats_dir = os.path.join(save_path, '统计结果')
    os.makedirs(stats_dir, exist_ok=True)  # also (re)creates save_path itself
    total_stats = {cls: {'TP': 0, 'FP': 0, 'FN': 0} for cls in classes}

    with open(os.path.join(stats_dir, 'result.txt'), 'w', encoding='utf-8') as f_w:
        label_files = sorted(name for name in os.listdir(label_path) if name.endswith('.txt'))
        for path in tqdm.tqdm(label_files):
            stem = os.path.splitext(path)[0]
            image_file = f'{img_path}/{stem}.{postfix}'
            image = cv2.imread(image_file)
            if image is None:
                # Without the image there is no size to scale boxes by and
                # nothing to draw on, so skip this sample.
                print(f'image:{image_file} 未找到.', file=f_w)
                continue
            h, w = image.shape[:2]

            try:
                pred = _load_yolo_boxes(f'{predict_path}/{path}', w, h)
            except (OSError, ValueError):
                # No readable prediction file: treat as "nothing detected".
                pred = no_boxes
            try:
                label = _load_yolo_boxes(f'{label_path}/{path}', w, h)
            except (OSError, ValueError):
                # Start from an empty set rather than reusing the previous image's labels.
                label = no_boxes
            if len(label) == 0:
                print(f'label path:{label_path}/{path} (未找到或无目标).', file=f_w)

            matched, missed, extra = _match_boxes(label, pred, iou_threshold)
            class_stats = {cls: {'TP': 0, 'FP': 0, 'FN': 0} for cls in classes}
            for rows, key, color in ((matched, 'TP', detect_color),
                                     (missed, 'FN', missing_color),
                                     (extra, 'FP', error_color)):
                for row in rows:
                    draw_box(image, row[1:5], color)
                    class_stats[classes[int(row[0])]][key] += 1

            print(f'name:{stem}', file=f_w)
            for cls in classes:
                stat = class_stats[cls]
                print(_format_counts(cls, stat), file=f_w)
                for key in stat:
                    total_stats[cls][key] += stat[key]

            image_stats = {'TP': len(matched), 'FP': len(extra), 'FN': len(missed)}
            print(_format_counts('单次统计', image_stats), file=f_w)
            # Overlay this image's counts, each in the same color as its boxes.
            draw_text(image, f"TP: {image_stats['TP']}", (10, 40), color=detect_color)
            draw_text(image, f"FN: {image_stats['FN']}", (10, 75), color=missing_color)
            draw_text(image, f"FP: {image_stats['FP']}", (10, 110), color=error_color)
            cv2.imwrite(f'{save_path}/{stem}.{postfix}', image)

        for cls, stat in total_stats.items():
            print(_format_counts(cls, stat), file=f_w)
        all_stats = {key: sum(stat[key] for stat in total_stats.values()) for key in ('TP', 'FP', 'FN')}
        all_result = _format_counts('最终结果', all_stats)
        print(all_result, file=f_w)

    print(total_stats)
    print(all_result)
效果如下: