Introduction
This article shows how to run YOLOv8 Oriented Bounding Box (OBB) inference with ONNX. The example is written in Python and walks through loading the model, preprocessing the image, running inference, and postprocessing the results.
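If you do not already have an ONNX file, the model can be exported from the official Ultralytics weights. The snippet below is a minimal sketch and assumes the ultralytics package is installed; yolov8s-obb.pt is the stock Ultralytics checkpoint.

from ultralytics import YOLO

# Export the pretrained OBB checkpoint to ONNX (produces yolov8s-obb.onnx)
YOLO("yolov8s-obb.pt").export(format="onnx")
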
Code
The complete code for YOLOv8 OBB inference is listed below:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
# @FileName : YOLOv8_OBB.py
# @Time : 2024-07-25 17:33:48
# @Author : XuMing
# @Email : 920972751@qq.com
# @description : YOLOv8 Oriented Bounding Box Inference using ONNX
"""
import cv2
import math
import random
import numpy as np
import onnxruntime as ort
from loguru import logger
class RotatedBOX:
def __init__(self, box, score, class_index):
self.box = box
self.score = score
self.class_index = class_index
class ONNXInfer:
def __init__(self, onnx_model, class_names, device='auto', conf_thres=0.5, nms_thres=0.4) -> None:
self.onnx_model = onnx_model
self.class_names = class_names
self.conf_thres = conf_thres
self.nms_thres = nms_thres
self.device = self._select_device(device)
logger.info(f"Loading model on {self.device}...")
self.session_model = ort.InferenceSession(
self.onnx_model,
providers=self.device,
sess_options=self._get_session_options()
)
def _select_device(self, device):
"""
Select the appropriate device.
:param device: 'auto', 'cuda', or 'cpu'.
:return: List of providers.
"""
if device == 'cuda' or (device == 'auto' and ort.get_device() == 'GPU'):
return ['CUDAExecutionProvider', 'CPUExecutionProvider']
return ['CPUExecutionProvider']
def _get_session_options(self):
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
sess_options.intra_op_num_threads = 4
return sess_options
def preprocess(self, img):
"""
Preprocess the image for inference.
:param img: Input image.
        :return: Preprocessed blob, the padded square image, and the original image width and height.
"""
logger.info(
"Preprocessing input image to [1, channels, input_w, input_h] format")
height, width = img.shape[:2]
length = max(height, width)
image = np.zeros((length, length, 3), np.uint8)
image[0:height, 0:width] = img
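        # Pad the image onto a square canvas (top-left aligned) so the later
        # resize to the model's square input does not distort the aspect ratio.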
input_shape = self.session_model.get_inputs()[0].shape[2:]
logger.debug(f"Input shape: {input_shape}")
        # blobFromImage expects size as (width, height); input_shape is (height, width)
        blob = cv2.dnn.blobFromImage(
            image, scalefactor=1 / 255,
            size=(int(input_shape[1]), int(input_shape[0])), swapRB=True)
logger.info(f"Preprocessed image blob shape: {blob.shape}")
return blob, image, width, height
def predict(self, img):
"""
Perform inference on the image.
:param img: Input image.
        :return: List of RotatedBOX objects after postprocessing.
"""
blob, resized_image, orig_width, orig_height = self.preprocess(img)
inputs = {self.session_model.get_inputs()[0].name: blob}
try:
outputs = self.session_model.run(None, inputs)
except Exception as e:
logger.error(f"Inference failed: {e}")
raise
return self.postprocess(outputs, resized_image, orig_width, orig_height)
def postprocess(self, outputs, resized_image, orig_width, orig_height):
"""
Postprocess the model output.
:param outputs: Model outputs.
:param resized_image: Resized image used for inference.
:param orig_width: Original image width.
:param orig_height: Original image height.
:return: List of RotatedBOX objects.
"""
output_data = outputs[0]
logger.info(
f"Postprocessing output data with shape: {output_data.shape}")
input_shape = self.session_model.get_inputs()[0].shape[2:]
x_factor = resized_image.shape[1] / float(input_shape[1])
y_factor = resized_image.shape[0] / float(input_shape[0])
flattened_output = output_data.flatten()
reshaped_output = np.reshape(
flattened_output, (output_data.shape[1], output_data.shape[2])).T
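        # YOLOv8-OBB output has shape (1, 4 + num_classes + 1, num_predictions);
        # transposing gives one row per prediction:
        # [cx, cy, w, h, class scores..., rotation angle in radians].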
detected_boxes = []
confidences = []
rotated_boxes = []
num_classes = len(self.class_names)
for detection in reshaped_output:
class_scores = detection[4:4 + num_classes]
class_id = np.argmax(class_scores)
confidence_score = class_scores[class_id]
if confidence_score > self.conf_thres:
cx, cy, width, height = detection[:4] * \
[x_factor, y_factor, x_factor, y_factor]
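                # The last element of each prediction is the box rotation in radians.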
angle = detection[4 + num_classes]
if 0.5 * math.pi <= angle <= 0.75 * math.pi:
angle -= math.pi
box = ((cx, cy), (width, height), angle * 180 / math.pi)
rotated_box = RotatedBOX(box, confidence_score, class_id)
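                # cv2.dnn.NMSBoxes works on axis-aligned (x, y, w, h) boxes, so the
                # rotated box's bounding rectangle is used as an approximation for NMS.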
detected_boxes.append(cv2.boundingRect(cv2.boxPoints(box)))
rotated_boxes.append(rotated_box)
confidences.append(confidence_score)
        nms_indices = cv2.dnn.NMSBoxes(
            detected_boxes, [float(c) for c in confidences], self.conf_thres, self.nms_thres)
        # NMSBoxes returns an empty tuple when nothing survives the thresholds
        nms_indices = np.asarray(nms_indices).flatten() if len(nms_indices) > 0 else []
        remain_boxes = [rotated_boxes[i] for i in nms_indices]
logger.info(f"Detected {len(remain_boxes)} objects after NMS")
return remain_boxes
def generate_colors(self, num_classes):
"""
Generate a list of distinct colors for each class.
:param num_classes: Number of classes.
:return: List of RGB color tuples.
"""
colors = []
for _ in range(num_classes):
colors.append((random.randint(0, 255), random.randint(
0, 255), random.randint(0, 255)))
return colors
def drawshow(self, original_image, detected_boxes, class_labels):
"""
Draw detected bounding boxes and labels on the image and display it.
:param original_image: The input image on which to draw the boxes.
:param detected_boxes: List of detected RotatedBOX objects.
:param class_labels: List of class labels.
"""
# Generate random colors for each class
num_classes = len(class_labels)
colors = self.generate_colors(num_classes)
for detected_box in detected_boxes:
box = detected_box.box
points = cv2.boxPoints(box)
            # Box coordinates were already rescaled to image coordinates in postprocess
            points = points.astype(np.int32)
class_id = detected_box.class_index
# Draw the bounding box with the color for the class
color = colors[class_id]
cv2.polylines(original_image, [points],
isClosed=True, color=color, thickness=2)
# Put the class label text with the same color
cv2.putText(original_image, class_labels[class_id], (points[0][0], points[0][1]),
cv2.FONT_HERSHEY_PLAIN, 1.0, color, 1)
# Display the image with drawn boxes
cv2.imshow("Detected Objects", original_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
if __name__ == "__main__":
img_path = "OIP-C.jpg"
model_path = "yolov8s-obb.onnx"
class_names = [
"plane", "ship", "storage tank", "baseball diamond", "tennis court",
"basketball court", "ground track field", "harbor", "bridge",
"large vehicle", "small vehicle", "helicopter", "roundabout",
"soccer ball field", "swimming pool"
]
img = cv2.imread(img_path)
if img is None:
logger.error(f"Failed to load image: {img_path}")
else:
app = ONNXInfer(onnx_model=model_path, class_names=class_names)
predictions = app.predict(img)
# logger.info(f"Inference results: {predictions}")
app.drawshow(img, predictions, class_names)
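The script depends on opencv-python, numpy, loguru, and onnxruntime (or onnxruntime-gpu if you want the CUDAExecutionProvider). Note that drawshow opens an interactive window; on a headless machine you can save the annotated image to disk instead. The following is a minimal sketch that reuses the predictions from the main block; result.jpg is an arbitrary output path.

# Alternative to drawshow on headless machines: draw the rotated boxes and save to disk.
for det in predictions:
    pts = cv2.boxPoints(det.box).astype(np.int32)
    cv2.polylines(img, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite("result.jpg", img)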