TP FP FN统计结果并映射到图片上

 将模型真实检测结果(包括TP,FPFN,对应正检,漏检,错检)统计下来,并将结果映射于图片。适用于yolo格式的标签和检测生成的yolo标签

需要先将检测结果保存,如下所示

import os
import cv2
import tqdm
import shutil
import numpy as np

def xywh2xyxy(box):
    box[:, 0] = box[:, 0] - box[:, 2] / 2
    box[:, 1] = box[:, 1] - box[:, 3] / 2
    box[:, 2] = box[:, 0] + box[:, 2]
    box[:, 3] = box[:, 1] + box[:, 3]
    return box

def iou(box1, box2):
    x11, y11, x12, y12 = np.split(box1, 4, axis=1)
    x21, y21, x22, y22 = np.split(box2, 4, axis=1)

    xa = np.maximum(x11, np.transpose(x21))
    xb = np.minimum(x12, np.transpose(x22))
    ya = np.maximum(y11, np.transpose(y21))
    yb = np.minimum(y12, np.transpose(y22))

    area_inter = np.maximum(0, (xb - xa + 1)) * np.maximum(0, (yb - ya + 1))

    area_1 = (x12 - x11 + 1) * (y12 - y11 + 1)
    area_2 = (x22 - x21 + 1) * (y22 - y21 + 1)
    area_union = area_1 + np.transpose(area_2) - area_inter

    iou = area_inter / area_union
    return iou


def draw_box(img, box, color):
    cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness=1)
    return img

def draw_text(img, text, position, color=(255, 255, 255)):
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    font_thickness = 2
    cv2.putText(img, text, position, font, font_scale, color, font_thickness, cv2.LINE_AA)
    return img

if __name__ == '__main__':
    # 推理的图片格式
    postfix = 'jpg'
    # 已推理图片路径
    img_path = 'Y:\yolov5-v7.0\\test_exp\detect\\result'
    # 标签路径
    label_path = 'Y:\yolov5-v7.0\my_dataset\\neu\labels/test'
    # 已推理标签路径
    predict_path = 'Y:\yolov5-v7.0\\test_exp\detect\\result\labels'
    # 文件保存路径(不要放在其他已经生成过文件的同级目录下,需要新建目录)
    save_path = 'Y:\yolov5-v7.0/test_exp\TP_FP_FN'
    # 目标类别
    classes = ['Cr', 'In', 'pa', 'PS', 'RS', 'Sc']
    # classes = ['pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
    detect_color, missing_color, error_color = (0, 255, 0), (0, 0, 255), (255, 0, 0)
    iou_threshold = 0.45

    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(os.path.join(save_path, '统计结果'), exist_ok=True)  # 创建TN_FN_FP子目录

    total_stats = {cls: {'TP': 0, 'FP': 0, 'FN': 0} for cls in classes}

    with open(os.path.join(save_path, '统计结果', 'result.txt'), 'w') as f_w:
        for path in tqdm.tqdm(os.listdir(label_path)):
            image = cv2.imread(f'{img_path}/{path[:-4]}.{postfix}')
            if image is None:
                print(f'image:{img_path}/{path[:-4]}.{postfix} 未找到.', file=f_w)
            h, w = image.shape[:2]

            try:
                with open(f'{predict_path}/{path}') as f:
                    pred = np.array(list(map(lambda x: np.array(x.strip().split(), dtype=np.float32), f.readlines())))
                    pred[:, 1:5] = xywh2xyxy(pred[:, 1:5])
                    pred[:, [1, 3]] *= w
                    pred[:, [2, 4]] *= h
                    pred = list(pred)
            except:
                pred = []

            try:
                with open(f'{label_path}/{path}') as f:
                    label = np.array(list(map(lambda x: np.array(x.strip().split(), dtype=np.float32), f.readlines())))
                    label[:, 1:] = xywh2xyxy(label[:, 1:])
                    label[:, [1, 3]] *= w
                    label[:, [2, 4]] *= h
            except:
                print(f'label path:{label_path}/{path} (未找到或无目标).', file=f_w)

            class_stats = {cls: {'TP': 0, 'FP': 0, 'FN': 0} for cls in classes}

            for i in range(label.shape[0]):
                if len(pred) == 0: break
                ious = iou(label[i:i + 1, 1:], np.array(pred)[:, 1:5])[0]
                ious_argsort = ious.argsort()[::-1]
                missing = True
                for j in ious_argsort:
                    if ious[j] < iou_threshold: break
                    if label[i, 0] == pred[j][0]:
                        image = draw_box(image, pred[j][1:5], detect_color)
                        pred.pop(j)
                        missing = False
                        class_stats[classes[int(label[i, 0])]]['TP'] += 1
                        break

                if missing:
                    image = draw_box(image, label[i][1:5], missing_color)
                    class_stats[classes[int(label[i, 0])]]['FP'] += 1

            if len(pred):
                for j in range(len(pred)):
                    image = draw_box(image, pred[j][1:5], error_color)
                    class_stats[classes[int(pred[j][0])]]['FN'] += 1

            print(f'name:{path[:-4]}', file=f_w)
            for cls in classes:
                stat = class_stats[cls]
                print(f"{cls}: right(TP):{stat['TP']} missing(FP):{stat['FP']} error(FN):{stat['FN']}", file=f_w)

                total_stats[cls]['TP'] += stat['TP']
                total_stats[cls]['FP'] += stat['FP']
                total_stats[cls]['FN'] += stat['FN']

            total_TP = sum(stat['TP'] for stat in class_stats.values())
            total_FP = sum(stat['FP'] for stat in class_stats.values())
            total_FN = sum(stat['FN'] for stat in class_stats.values())
            print(f"单次统计: right(TP):{total_TP} missing(FP):{total_FP} error(FN):{total_FN}", file=f_w)

            # 在图片上映射TP_FP_FN
            draw_text(image, f"TP: {total_TP}", (10, 40), color=(0, 255, 0))  # TP 对应绿色
            draw_text(image, f"FP: {total_FP}", (10, 75), color=(0, 0, 255))  # FP 对应蓝色
            draw_text(image, f"FN: {total_FN}", (10, 110), color=(255, 0, 0))  # FN 对应红色

            cv2.imwrite(f'{save_path}/{path[:-4]}.{postfix}', image)

        for cls, stat in total_stats.items():
            print(f"{cls}: right(TP):{stat['TP']} missing(FP):{stat['FP']} error(FN):{stat['FN']}", file=f_w)

        all_TP = sum(stat['TP'] for stat in total_stats.values())
        all_FP = sum(stat['FP'] for stat in total_stats.values())
        all_FN = sum(stat['FN'] for stat in total_stats.values())

        all_result = f"最终结果: right(TP):{all_TP} missing(FP):{all_FP} error(FN):{all_FN}"
        print(all_result, file=f_w)
        print(total_stats)
        print(all_result)

效果

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值