Using YOLOv8 + SAHI to Improve Small-Object Detection and Compute Evaluation Metrics


Introduction

I've recently seen quite a few people who want to evaluate YOLO + SAHI but don't know the concrete steps. What I found online was either fairly convoluted or behind a paywall, so I decided to write the code myself, covering the whole pipeline: model loading, image processing, visualization of detections, and computation of evaluation metrics. The code is essentially plug-and-play and supports YOLOv5, YOLOv8, and other models. Without further ado, let's get started! (If you're in a hurry, jump straight to the end and copy the full code.)



Requirements

We need the following libraries:
1. OpenCV (cv2)
2. SAHI
3. tabulate
4. podm
5. tqdm
6. argparse (part of the Python standard library, so no separate installation is needed)

Install them with:

pip install opencv-python sahi tabulate podm tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple

Note: SAHI's YOLOv8 wrapper also relies on the ultralytics package; if you trained your YOLOv8 model with Ultralytics, it will already be installed.

I. Code Structure

1. Command-Line Argument Parsing

First, we use the argparse module to define command-line arguments so that every setting can be configured flexibly:

parser = argparse.ArgumentParser(description="Object Detection Evaluation Script")
parser.add_argument('--filepath', type=str, default='val/images', help='Path to the images folder')
parser.add_argument('--annotation_folder', type=str, default='val/labels', help='Path to the annotation folder')

parser.add_argument('--model_type', type=str, default='yolov8', help='Type of the detection model')
parser.add_argument('--model_path', type=str, default='kitti_baseline/weights/best.pt',
                    help='Path to the model weights')
parser.add_argument('--confidence_threshold', type=float, default=0.4, help='Confidence threshold for the model')
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')

parser.add_argument('--slice_height', type=int, default=256, help='Height of the image slices')
parser.add_argument('--slice_width', type=int, default=256, help='Width of the image slices')
parser.add_argument('--overlap_height_ratio', type=float, default=0.2, help='Overlap height ratio for slicing')
parser.add_argument('--overlap_width_ratio', type=float, default=0.2, help='Overlap width ratio for slicing')

parser.add_argument('--visualize_predictions', action='store_true', default=False, help='Visualize prediction results')
parser.add_argument('--visualize_annotations', action='store_true', default=False, help='Visualize annotation results')

parser.add_argument('--class_list', type=str, nargs='+',
                    default=['Pedestrian', 'Car', 'Van', 'Truck', 'Person_sitting', 'Cyclist', 'Tram'],
                    help='List of class names')
parser.add_argument('--images_format', type=str, nargs='+', default=['.png', '.jpg', '.jpeg'],
                    help='List of acceptable image formats')

args = parser.parse_args()

Key parameters (an example invocation follows the list):

  • --filepath: path to the image folder (images)

  • --annotation_folder: path to the annotation folder (labels)

  • --model_type: type of detection model (default: yolov8)

  • --model_path: path to the model weights file

  • --confidence_threshold: confidence threshold for the model (only detections scoring above this threshold are kept)

  • --device: device to run the model on (e.g. cuda:0 or cpu)

  • --slice_height: height of each image slice

  • --slice_width: width of each image slice

  • --overlap_height_ratio: vertical overlap ratio between adjacent slices

  • --overlap_width_ratio: horizontal overlap ratio between adjacent slices

  • --visualize_predictions: if set, prediction results are visualized

  • --visualize_annotations: if set, annotation (ground-truth) results are visualized

  • --class_list: list of class names (you can copy the list that follows the names key in your dataset's .yaml file)

  • --images_format: list of accepted image file extensions
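For example, assuming the script is saved as eval_sahi.py (the file name is up to you) and your validation images and YOLO-format labels sit in val/images and val/labels, a typical run with 512×512 slices and prediction visualization might look like this (all paths here are placeholders to replace with your own):

python eval_sahi.py --filepath val/images --annotation_folder val/labels --model_path runs/train/exp/weights/best.pt --slice_height 512 --slice_width 512 --visualize_predictions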

2. Core Code Walkthrough

1. Loading the detection model

We call AutoDetectionModel.from_pretrained to load the YOLOv8 model:

def load_detection_model():
    return AutoDetectionModel.from_pretrained(
        model_type=args.model_type,
        model_path=args.model_path,
        confidence_threshold=args.confidence_threshold,
        device=args.device
    )
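Because AutoDetectionModel is a generic SAHI wrapper, switching to another supported model only means changing model_type and model_path. As a minimal sketch (the weights path below is a placeholder), loading a YOLOv5 model instead would look like:

model = AutoDetectionModel.from_pretrained(
    model_type='yolov5',        # SAHI also supports the 'yolov5' model type
    model_path='yolov5s.pt',    # placeholder path to your YOLOv5 weights
    confidence_threshold=0.4,
    device='cuda:0'
)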

2. Processing images

We define a process_image function that handles a single image. It:

  • reads the image from the given path
  • runs sliced prediction with SAHI
  • reads the ground-truth boxes from the annotation file
  • optionally visualizes the predictions and the annotations
def process_image(image_name, model, labels, detections):
    img_path = os.path.join(args.filepath, image_name)
    img_vis = cv2.imread(img_path)
    img_h, img_w, _ = img_vis.shape

    result = get_sliced_prediction(
        img_path,
        model,
        slice_height=args.slice_height,
        slice_width=args.slice_width,
        overlap_height_ratio=args.overlap_height_ratio,
        overlap_width_ratio=args.overlap_width_ratio,
        verbose = 0
    )

    anno_file = os.path.join(args.annotation_folder, os.path.splitext(image_name)[0] + '.txt')
    annotations = read_boxes(anno_file, img_w, img_h)

    for anno in annotations:
        label, xmin_gt, ymin_gt, xmax_gt, ymax_gt = anno
        labels.append(BoundingBox.of_bbox(image_name, label, xmin_gt, ymin_gt, xmax_gt, ymax_gt))
        if args.visualize_annotations:
            cv2.rectangle(img_vis, (int(xmin_gt), int(ymin_gt)), (int(xmax_gt), int(ymax_gt)), get_color(label), 2,
                          cv2.LINE_AA)
            cv2.putText(img_vis, f"{args.class_list[label]}", (int(xmin_gt), int(ymin_gt - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(label), thickness=2)

    for pred in result.object_prediction_list:
        bbox = pred.bbox
        cls = pred.category.id
        score = pred.score.value
        xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
        detections.append(BoundingBox.of_bbox(image_name, cls, xmin_pd, ymin_pd, xmax_pd, ymax_pd, score))

        if args.visualize_predictions:
            cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                          get_color(cls + len(args.class_list)), 2, cv2.LINE_AA)
            cv2.putText(img_vis, f"{args.class_list[cls]} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls + len(args.class_list)), thickness=2)

    if args.visualize_predictions or args.visualize_annotations:
        cv2.imshow(image_name, img_vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
2.1 Reading the image and annotation file
anno_file = os.path.join(args.annotation_folder, os.path.splitext(image_name)[0] + '.txt')
annotations = read_boxes(anno_file, img_w, img_h)
  • anno_file: builds the path to the annotation file (os.path.splitext is used rather than image_name[:-4] so that four-character extensions such as .jpeg also work)
  • annotations: reads the ground-truth boxes with read_boxes; a worked example of the coordinate conversion follows
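read_boxes assumes YOLO-format annotation files: each line is class_id cx cy w h, with the box center and size normalized to [0, 1]. As a quick worked example of the conversion it performs (all numbers are made up for illustration):

# One YOLO label line for a hypothetical 1242x375 image:
# "1 0.500 0.400 0.100 0.200"  ->  class 1, center (0.5, 0.4), size (0.1, 0.2)
img_w, img_h = 1242, 375
cx, cy, w, h = 0.500, 0.400, 0.100, 0.200

xmin = (cx - w / 2) * img_w   # (0.5 - 0.05) * 1242 = 558.9
ymin = (cy - h / 2) * img_h   # (0.4 - 0.10) * 375  = 112.5
xmax = (cx + w / 2) * img_w   # (0.5 + 0.05) * 1242 = 683.1
ymax = (cy + h / 2) * img_h   # (0.4 + 0.10) * 375  = 187.5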
2.2 Processing ground-truth boxes
for anno in annotations:
    label, xmin_gt, ymin_gt, xmax_gt, ymax_gt = anno
    labels.append(BoundingBox.of_bbox(image_name, label, xmin_gt, ymin_gt, xmax_gt, ymax_gt))
    if args.visualize_annotations:
        cv2.rectangle(img_vis, (int(xmin_gt), int(ymin_gt)), (int(xmax_gt), int(ymax_gt)), get_color(label), 2,
                      cv2.LINE_AA)
        cv2.putText(img_vis, f"{args.class_list[label]}", (int(xmin_gt), int(ymin_gt - 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(label), thickness=2)
  • loop over every ground-truth box in annotations
  • labels.append: add each ground-truth box to the labels list
  • if args.visualize_annotations is True, draw the box on the image and write the class name above it
2.3 Processing predictions
for pred in result.object_prediction_list:
    bbox = pred.bbox
    cls = pred.category.id
    score = pred.score.value
    xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
    detections.append(BoundingBox.of_bbox(image_name, cls, xmin_pd, ymin_pd, xmax_pd, ymax_pd, score))

    if args.visualize_predictions:
        cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                      get_color(cls + len(args.class_list)), 2, cv2.LINE_AA)
        cv2.putText(img_vis, f"{args.class_list[cls]} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls + len(args.class_list)), thickness=2)
  • loop over every prediction in result.object_prediction_list
  • detections.append: add each prediction to the detections list
  • if args.visualize_predictions is True, draw the predicted box on the image and write the class name and confidence score above it (get_color(cls + len(args.class_list)) offsets the color index so predictions get different colors from ground-truth boxes of the same class)
2.4 Displaying the image
if args.visualize_predictions or args.visualize_annotations:
    cv2.imshow(image_name, img_vis)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
  • if args.visualize_predictions=True, the prediction results are shown
  • if args.visualize_annotations=True, the annotation results are shown
  • if both flags are True, predictions and annotations are shown together
  • if both flags are False, nothing is displayed and the script goes straight to computing the evaluation metrics (a disk-writing alternative for headless machines is sketched below)
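Note that cv2.imshow needs a display. If you are running on a headless server, a simple workaround is to save the visualization to disk instead; a minimal sketch (the vis_results folder name is just an example) replacing the block above:

    if args.visualize_predictions or args.visualize_annotations:
        # On a headless machine, write the image to disk instead of calling cv2.imshow
        os.makedirs("vis_results", exist_ok=True)
        cv2.imwrite(os.path.join("vis_results", image_name), img_vis)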

3. Evaluating the model

We use the podm library to compute PASCAL VOC evaluation metrics, using an IoU threshold of 0.5:

def evaluate_model(labels, detections):
    results = get_pascal_voc_metrics(labels, detections, 0.5)
    table = [
        [args.class_list[int(class_id)], m.recall[-1], m.precision[-1], m.ap]
        for class_id, m in results.items() if m.num_groundtruth > 0
    ]
    map_score = MetricPerClass.mAP(results)
    print(tabulate(table, headers=["Class", "Recall", "Precision", "AP"], floatfmt=".2f"))
    print(f"\nmAP: {map_score:.4f}")

4. Main function

The main function loads the model, walks through the image folder, processes each image, and finally evaluates the model:

def main():
    detection_model = load_detection_model()
    image_names = [name for name in os.listdir(args.filepath) if
                   os.path.splitext(name)[1].lower() in args.images_format]
    labels, detections = [], []

    for image_name in tqdm(image_names, desc="Processing images"):
        process_image(image_name, detection_model, labels, detections)

    evaluate_model(labels, detections)

if __name__ == "__main__":
    main()

II. Full Code

import os
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from tabulate import tabulate
from podm.metrics import BoundingBox, get_pascal_voc_metrics, MetricPerClass
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser(description="Object Detection Evaluation Script")
parser.add_argument('--filepath', type=str, default='val/images', help='Path to the images folder')
parser.add_argument('--annotation_folder', type=str, default='val/labels', help='Path to the annotation folder')

parser.add_argument('--model_type', type=str, default='yolov8', help='Type of the detection model')
parser.add_argument('--model_path', type=str, default='kitti_baseline/weights/best.pt',
                    help='Path to the model weights')
parser.add_argument('--confidence_threshold', type=float, default=0.4, help='Confidence threshold for the model')
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')

parser.add_argument('--slice_height', type=int, default=256, help='Height of the image slices')
parser.add_argument('--slice_width', type=int, default=256, help='Width of the image slices')
parser.add_argument('--overlap_height_ratio', type=float, default=0.2, help='Overlap height ratio for slicing')
parser.add_argument('--overlap_width_ratio', type=float, default=0.2, help='Overlap width ratio for slicing')

parser.add_argument('--visualize_predictions', action='store_true', default=False, help='Visualize prediction results')
parser.add_argument('--visualize_annotations', action='store_true', default=False, help='Visualize annotation results')

parser.add_argument('--class_list', type=str, nargs='+',
                    default=['Pedestrian', 'Car', 'Van', 'Truck', 'Person_sitting', 'Cyclist', 'Tram'],
                    help='List of class names')
parser.add_argument('--images_format', type=str, nargs='+', default=['.png', '.jpg', '.jpeg'],
                    help='List of acceptable image formats')

args = parser.parse_args()


def get_color(idx):
    idx = int(idx) + 5
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)


def read_boxes(txt_file, img_w, img_h):
    # Read YOLO-format labels (class cx cy w h, normalized to [0, 1]) and convert
    # them to absolute [class, xmin, ymin, xmax, ymax] pixel coordinates.
    boxes = []
    with open(txt_file, 'r') as f:
        for line in f:
            items = line.strip().split()
            box = [
                int(items[0]),                                      # class id
                (float(items[1]) - float(items[3]) / 2) * img_w,    # xmin
                (float(items[2]) - float(items[4]) / 2) * img_h,    # ymin
                (float(items[1]) + float(items[3]) / 2) * img_w,    # xmax
                (float(items[2]) + float(items[4]) / 2) * img_h     # ymax
            ]
            boxes.append(box)
    return boxes


def load_detection_model():
    return AutoDetectionModel.from_pretrained(
        model_type=args.model_type,
        model_path=args.model_path,
        confidence_threshold=args.confidence_threshold,
        device=args.device
    )


def process_image(image_name, model, labels, detections):
    img_path = os.path.join(args.filepath, image_name)
    img_vis = cv2.imread(img_path)
    img_h, img_w, _ = img_vis.shape

    result = get_sliced_prediction(
        img_path,
        model,
        slice_height=args.slice_height,
        slice_width=args.slice_width,
        overlap_height_ratio=args.overlap_height_ratio,
        overlap_width_ratio=args.overlap_width_ratio,
        verbose = 0
    )

    anno_file = os.path.join(args.annotation_folder, os.path.splitext(image_name)[0] + '.txt')
    annotations = read_boxes(anno_file, img_w, img_h)

    for anno in annotations:
        label, xmin_gt, ymin_gt, xmax_gt, ymax_gt = anno
        labels.append(BoundingBox.of_bbox(image_name, label, xmin_gt, ymin_gt, xmax_gt, ymax_gt))
        if args.visualize_annotations:
            cv2.rectangle(img_vis, (int(xmin_gt), int(ymin_gt)), (int(xmax_gt), int(ymax_gt)), get_color(label), 2,
                          cv2.LINE_AA)
            cv2.putText(img_vis, f"{args.class_list[label]}", (int(xmin_gt), int(ymin_gt - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(label), thickness=2)

    for pred in result.object_prediction_list:
        bbox = pred.bbox
        cls = pred.category.id
        score = pred.score.value
        xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
        detections.append(BoundingBox.of_bbox(image_name, cls, xmin_pd, ymin_pd, xmax_pd, ymax_pd, score))

        if args.visualize_predictions:
            cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                          get_color(cls + len(args.class_list)), 2, cv2.LINE_AA)
            cv2.putText(img_vis, f"{args.class_list[cls]} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls + len(args.class_list)), thickness=2)

    if args.visualize_predictions or args.visualize_annotations:
        cv2.imshow(image_name, img_vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


def evaluate_model(labels, detections):
    results = get_pascal_voc_metrics(labels, detections, 0.5)
    table = [
        [args.class_list[int(class_id)], m.recall[-1], m.precision[-1], m.ap]
        for class_id, m in results.items() if m.num_groundtruth > 0
    ]
    map_score = MetricPerClass.mAP(results)
    print(tabulate(table, headers=["Class", "Recall", "Precision", "AP"], floatfmt=".2f"))
    print(f"\nmAP: {map_score:.4f}")


def main():
    detection_model = load_detection_model()
    image_names = [name for name in os.listdir(args.filepath) if
                   os.path.splitext(name)[1].lower() in args.images_format]
    labels, detections = [], []

    for image_name in tqdm(image_names, desc="Processing images"):
        process_image(image_name, detection_model, labels, detections)

    evaluate_model(labels, detections)


if __name__ == "__main__":
    main()

III. Results

Computing the evaluation metrics

(image placeholder)

Visualizing the prediction results

(image placeholder)

Visualizing the annotation results

(image placeholder)

Visualizing predictions and annotations together

(image placeholder)


Summary

That wraps up this post. If you found it helpful, feel free to follow. Thanks!

I've recently been posting object-detection videos on Bilibili; if you're interested, come check them out: https://b23.tv/1upjbcG

Study and discussion group (QQ): 995760755
