YOLOv8-OBB ONNX Runtime 推理部署

简介

本文介绍如何使用 ONNX Runtime 进行 YOLOv8 旋转框检测(Oriented Bounding Box, OBB)推理。示例代码使用 Python 编写,完成图像处理与目标检测,并演示如何加载模型、预处理图像、执行推理以及对结果进行后处理。

代码

以下是实现 YOLOv8 OBB 推理的完整代码:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
# @FileName      : YOLOv8_OBB.py
# @Time          : 2024-07-25 17:33:48
# @Author        : XuMing
# @Email         : 920972751@qq.com
# @description   : YOLOv8 Oriented Bounding Box Inference using ONNX
"""
import cv2
import math
import random
import numpy as np
import onnxruntime as ort
from loguru import logger

class RotatedBOX:
    """
    A single oriented detection result.

    :param box: OpenCV rotated-rect tuple ``((cx, cy), (w, h), angle)``; the
        inference code in this file builds it with the angle in degrees.
    :param score: Confidence score of the predicted class.
    :param class_index: Index of the predicted class in the class-name list.
    """

    def __init__(self, box, score, class_index):
        self.box = box
        self.score = score
        self.class_index = class_index

    def __repr__(self):
        # Readable form so that logging a list of detections is useful.
        return (f"{type(self).__name__}(box={self.box!r}, "
                f"score={self.score!r}, class_index={self.class_index!r})")

class ONNXInfer:
    """
    YOLOv8 oriented-bounding-box (OBB) inference with ONNX Runtime.

    Pipeline: pad the image to a square (bottom/right only), build an NCHW blob,
    run the session, decode the raw predictions into rotated boxes and apply NMS.
    """

    def __init__(self, onnx_model, class_names, device='auto', conf_thres=0.5, nms_thres=0.4) -> None:
        """
        Create the ONNX Runtime session.

        :param onnx_model: Model passed to ``ort.InferenceSession`` (e.g. a path to a ``.onnx`` file).
        :param class_names: List of class labels; its length determines how many
            class-score columns are read from the model output.
        :param device: 'auto', 'cuda', or 'cpu' (see ``_select_device``).
        :param conf_thres: A prediction is kept only if its best class score is strictly above this.
        :param nms_thres: IoU threshold used by non-maximum suppression.
        """
        self.onnx_model = onnx_model
        self.class_names = class_names
        self.conf_thres = conf_thres
        self.nms_thres = nms_thres
        # Despite its name, this attribute holds the ONNX Runtime provider list.
        self.device = self._select_device(device)

        logger.info(f"Loading model on {self.device}...")
        self.session_model = ort.InferenceSession(
            self.onnx_model,
            providers=self.device,
            sess_options=self._get_session_options()
        )

    def _select_device(self, device):
        """
        Select the execution providers.

        :param device: 'auto', 'cuda', or 'cpu'. With 'auto', CUDA is used when
            ``ort.get_device()`` reports 'GPU'.
        :return: List of providers, CUDA first (with CPU fallback) when selected.
        """
        if device == 'cuda' or (device == 'auto' and ort.get_device() == 'GPU'):
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']
        return ['CPUExecutionProvider']

    def _get_session_options(self):
        """Build session options: extended graph optimizations and 4 intra-op threads."""
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        sess_options.intra_op_num_threads = 4
        return sess_options

    def _input_hw(self):
        """Return the model input (height, width), read from dims 2 and 3 of the NCHW input shape."""
        input_h, input_w = self.session_model.get_inputs()[0].shape[2:]
        return int(input_h), int(input_w)

    def preprocess(self, img):
        """
        Pad the image to a square and convert it into a network input blob.

        The image is copied into the top-left corner of a zero-filled square
        canvas of side ``max(height, width)``. This keeps the aspect ratio, and
        canvas coordinates coincide with original-image coordinates.

        :param img: Input image, assumed (H, W, 3) uint8 as returned by ``cv2.imread``.
        :return: Tuple ``(blob, padded_square_image, original_width, original_height)``.
        """
        logger.info(
            "Preprocessing input image to [1, channels, input_h, input_w] format")
        height, width = img.shape[:2]
        length = max(height, width)
        image = np.zeros((length, length, 3), np.uint8)
        image[0:height, 0:width] = img

        input_h, input_w = self._input_hw()
        logger.debug(f"Input shape: {[input_h, input_w]}")

        # OpenCV's ``size`` argument is (width, height); the ONNX shape is (height, width).
        blob = cv2.dnn.blobFromImage(
            image, scalefactor=1 / 255, size=(input_w, input_h), swapRB=True)
        logger.info(f"Preprocessed image blob shape: {blob.shape}")

        return blob, image, width, height

    def predict(self, img):
        """
        Run preprocessing, inference and postprocessing on one image.

        :param img: Input image.
        :return: List of RotatedBOX objects that survive NMS.
        :raises Exception: Re-raises any error from ``session.run`` after logging it.
        """
        blob, resized_image, orig_width, orig_height = self.preprocess(img)
        inputs = {self.session_model.get_inputs()[0].name: blob}
        try:
            outputs = self.session_model.run(None, inputs)
        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise
        return self.postprocess(outputs, resized_image, orig_width, orig_height)

    @staticmethod
    def _decode_predictions(output_data, num_classes, conf_thres, x_factor, y_factor):
        """
        Decode raw model output into candidate rotated boxes (before NMS).

        :param output_data: Array of shape (1, 4 + num_classes + 1, N). Per prediction
            it holds cx, cy, w, h, ``num_classes`` class scores, then an angle in radians.
        :param num_classes: Number of class-score columns.
        :param conf_thres: Keep predictions whose best class score is strictly above this.
        :param x_factor: Scale from network-input x to padded-image x.
        :param y_factor: Scale from network-input y to padded-image y.
        :return: List of ``(box, score, class_id)`` in output order, where ``box`` is
            ``((cx, cy), (w, h), angle_degrees)``.
        """
        output_data = np.asarray(output_data)
        # (C, N) -> (N, C): one row per prediction.
        preds = output_data.reshape(output_data.shape[1], output_data.shape[2]).T

        scores = preds[:, 4:4 + num_classes]
        class_ids = np.argmax(scores, axis=1)
        confs = scores.max(axis=1)
        keep = confs > conf_thres
        preds, class_ids, confs = preds[keep], class_ids[keep], confs[keep]

        xywh = preds[:, :4].astype(np.float64) * np.array(
            [x_factor, y_factor, x_factor, y_factor], dtype=np.float64)
        angles = preds[:, 4 + num_classes].astype(np.float64)
        # Angles in [pi/2, 3pi/4] are shifted by -pi. A rectangle rotated by pi is
        # the same rectangle, so this changes only the representation.
        angles = np.where((angles >= 0.5 * math.pi) & (angles <= 0.75 * math.pi),
                          angles - math.pi, angles)
        degrees = angles * 180 / math.pi

        return [
            (((cx, cy), (w, h), deg), score, class_id)
            for (cx, cy, w, h), deg, score, class_id in zip(
                xywh.tolist(), degrees.tolist(), confs.tolist(), class_ids.tolist())
        ]

    def postprocess(self, outputs, resized_image, orig_width, orig_height):
        """
        Decode the model output and apply non-maximum suppression.

        :param outputs: Model outputs; only ``outputs[0]`` is used.
        :param resized_image: Padded square image returned by ``preprocess``.
        :param orig_width: Original image width (unused: padding is bottom/right only,
            so padded-image coordinates already equal original-image coordinates).
        :param orig_height: Original image height (unused, see ``orig_width``).
        :return: List of RotatedBOX objects (possibly empty).
        """
        output_data = outputs[0]
        logger.info(
            f"Postprocessing output data with shape: {output_data.shape}")

        input_h, input_w = self._input_hw()
        x_factor = resized_image.shape[1] / float(input_w)
        y_factor = resized_image.shape[0] / float(input_h)

        candidates = self._decode_predictions(
            output_data, len(self.class_names), self.conf_thres, x_factor, y_factor)
        rotated_boxes = [RotatedBOX(box, score, class_id)
                         for box, score, class_id in candidates]
        if not rotated_boxes:
            logger.info("Detected 0 objects after NMS")
            return []

        # NOTE: NMS runs on the axis-aligned bounding rectangles of the rotated
        # boxes. This approximates, but is not equal to, IoU between the rotated boxes.
        detected_boxes = [cv2.boundingRect(cv2.boxPoints(rb.box)) for rb in rotated_boxes]
        confidences = [float(rb.score) for rb in rotated_boxes]
        nms_indices = cv2.dnn.NMSBoxes(
            detected_boxes, confidences, self.conf_thres, self.nms_thres)
        # NMSBoxes may return an empty tuple instead of an ndarray.
        keep = np.asarray(nms_indices, dtype=np.int64).flatten()
        remain_boxes = [rotated_boxes[i] for i in keep]

        logger.info(f"Detected {len(remain_boxes)} objects after NMS")
        return remain_boxes

    def generate_colors(self, num_classes):
        """
        Generate one random color per class.

        :param num_classes: Number of classes.
        :return: List of 3-tuples of ints in [0, 255]. OpenCV drawing interprets them as BGR.
        """
        return [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
                for _ in range(num_classes)]

    def drawshow(self, original_image, detected_boxes, class_labels):
        """
        Draw detections onto the image (in place) and show it in a window.

        :param original_image: Image to draw on. It is modified in place.
        :param detected_boxes: List of RotatedBOX objects.
        :param class_labels: List of class labels indexed by ``class_index``.
        """
        colors = self.generate_colors(len(class_labels))

        for detected_box in detected_boxes:
            # Box coordinates already refer to the original image (padding is
            # bottom/right only). np.int0 was removed in NumPy 2.0, so the
            # corners are cast explicitly (truncating, as before).
            points = cv2.boxPoints(detected_box.box).astype(np.int32)

            class_id = detected_box.class_index
            color = colors[class_id]
            cv2.polylines(original_image, [points],
                          isClosed=True, color=color, thickness=2)
            # Label at the first corner, in the box's color.
            cv2.putText(original_image, class_labels[class_id],
                        (int(points[0][0]), int(points[0][1])),
                        cv2.FONT_HERSHEY_PLAIN, 1.0, color, 1)

        cv2.imshow("Detected Objects", original_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

if __name__ == "__main__":
    # Demo entry point: detect objects in one image and display the result.
    img_path = "OIP-C.jpg"
    model_path = "yolov8s-obb.onnx"
    class_names = [
        "plane", "ship", "storage tank", "baseball diamond", "tennis court",
        "basketball court", "ground track field", "harbor", "bridge",
        "large vehicle", "small vehicle", "helicopter", "roundabout",
        "soccer ball field", "swimming pool"
    ]

    img = cv2.imread(img_path)
    if img is None:
        logger.error(f"Failed to load image: {img_path}")
        # Report the failure to the caller/shell instead of exiting with status 0.
        raise SystemExit(1)

    app = ONNXInfer(onnx_model=model_path, class_names=class_names)
    predictions = app.predict(img)
    app.drawshow(img, predictions, class_names)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值