Converting a YOLOv8 or YOLOv5 model to ONNX

Introduction

Train a YOLOv8 or YOLOv5 model, convert it to ONNX, and wrap the ONNX model behind a simple inference interface.

YOLOv8

Training

yolo train model=yolov8n.pt data=/workspace/data/car_door/data.yaml device=0,1 epochs=2 batch=8 workers=8 imgsz=1280 half=True
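
The same run can also be driven from Python with the ultralytics API; a minimal sketch assuming the same dataset path as the CLI command above:

from ultralytics import YOLO

# Same training run as the CLI command above, driven from Python
model = YOLO("yolov8n.pt")
model.train(data="/workspace/data/car_door/data.yaml", device="0,1",
            epochs=2, batch=8, workers=8, imgsz=1280, half=True)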

Convert to ONNX. The trained weights are saved under runs/detect/ in the directory where the training command was executed:

yolo export model=best.pt format=onnx
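
Equivalently, the export can be done from Python; a sketch assuming the weights path from the training run above (the exact runs/detect subdirectory depends on your run name):

from ultralytics import YOLO

# Export the trained weights to ONNX; writes best.onnx next to best.pt
model = YOLO("runs/detect/train/weights/best.pt")
model.export(format="onnx")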

Install onnxruntime

# Pick one of the following. The inference machine only needs the onnxruntime
# dependency; it does not need the yolov8 package itself.
# GPU version
pip install onnxruntime-gpu -i https://pypi.tuna.tsinghua.edu.cn/simple
# CPU version
pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
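
A quick way to verify the install and confirm which execution providers your build supports:

import onnxruntime as ort

# Lists the providers compiled into this onnxruntime build;
# the GPU package should include 'CUDAExecutionProvider'
print(ort.__version__)
print(ort.get_available_providers())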

Wrapper

# -*- coding: utf-8 -*-
import cv2
import numpy as np
import onnxruntime as ort
import time

class YOLOv8:
    """YOLOv8 object detection model class for handling inference and visualization."""

    def __init__(self, onnx_model, confidence_thres, iou_thres, classes):
        """
        Initializes an instance of the YOLOv8 class.

        Args:
            onnx_model: Path to the ONNX model.
            confidence_thres: Confidence threshold for filtering detections.
            iou_thres: IoU (Intersection over Union) threshold for non-maximum suppression.
            classes: List of class names the model was trained on.
        """
        self.onnx_model = onnx_model
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres

        # Store the class names supplied by the caller
        self.classes = classes

        # Generate a color palette for the classes
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
        # Create an inference session using the ONNX model and specify execution providers
        self.session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

    def draw_detections(self, img, box, score, class_id):
        """
        Draws bounding boxes and labels on the input image based on the detected objects.

        Args:
            img: The input image to draw detections on.
            box: Detected bounding box.
            score: Corresponding detection score.
            class_id: Class ID for the detected object.

        Returns:
            None
        """

        # Extract the coordinates of the bounding box
        x1, y1, w, h = box

        # Retrieve the color for the class ID
        color = self.color_palette[class_id]

        # Draw the bounding box on the image
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

        # Create the label text with class name and score
        label = f"{self.classes[class_id]}: {score:.2f}"

        # Calculate the dimensions of the label text
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Calculate the position of the label text
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        # Draw a filled rectangle as the background for the label text
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )

        # Draw the label text on the image
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

    def preprocess(self):
        """
        Preprocesses the input image before performing inference.

        Returns:
            image_data: Preprocessed image data ready for inference.
        """
        # Get the height and width of the input image
        self.img_height, self.img_width = self.img.shape[:2]

        # Convert the image color space from BGR to RGB
        img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)

        # Resize the image to match the input shape
        img = cv2.resize(img, (self.input_width, self.input_height))

        # Normalize the image data by dividing it by 255.0
        image_data = np.array(img) / 255.0

        # Transpose the image to have the channel dimension as the first dimension
        image_data = np.transpose(image_data, (2, 0, 1))  # Channel first

        # Expand the dimensions of the image data to match the expected input shape
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)

        # Return the preprocessed image data
        return image_data

    def postprocess(self, img, output):
        """
        Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.

        Args:
            input_image (numpy.ndarray): The input image.
            output (numpy.ndarray): The output of the model.

        Returns:
            numpy.ndarray: The input image with detections drawn on it.
        """

        # Transpose and squeeze the output to match the expected shape
        outputs = np.transpose(np.squeeze(output[0]))

        # Get the number of rows in the outputs array
        rows = outputs.shape[0]

        # Lists to store the bounding boxes, scores, and class IDs of the detections
        boxes = []
        scores = []
        class_ids = []

        # Calculate the scaling factors for the bounding box coordinates
        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height

        # Iterate over each row in the outputs array
        for i in range(rows):
            # Extract the class scores from the current row
            classes_scores = outputs[i][4:]

            # Find the maximum score among the class scores
            max_score = np.amax(classes_scores)

            # If the maximum score is above the confidence threshold
            if max_score >= self.confidence_thres:
                # Get the class ID with the highest score
                class_id = np.argmax(classes_scores)

                # Extract the bounding box coordinates from the current row
                x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]

                # Calculate the scaled coordinates of the bounding box
                left = int((x - w / 2) * x_factor)
                top = int((y - h / 2) * y_factor)
                width = int(w * x_factor)
                height = int(h * y_factor)

                # Add the class ID, score, and box coordinates to the respective lists
                class_ids.append(class_id)
                scores.append(max_score)
                boxes.append([left, top, width, height])

        # Apply non-maximum suppression to filter out overlapping bounding boxes
        indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thres, self.iou_thres)

        # List to hold the final detection results
        detections = []
        
        # print("indices",indices)
        # Iterate over the selected indices after non-maximum suppression
        for i in indices:
            # Get the box, score, and class ID corresponding to the index
            box = boxes[i]
            score = scores[i]
            class_id = class_ids[i]
            print("type(score):",type(score))
            print("type(box)",type(box))
            detections.append({'class': self.classes[class_id], 'conf': float(score), 'position': box, 'index': int(class_id)})
            # Draw the detection on the input image
            self.draw_detections(img, box, score, class_id)
        print("detections:",detections)

        # Return the modified input image
        return img,detections

    def detect(self, img):
        """
        Performs inference using an ONNX model and returns the output image with drawn detections.

        Returns:
            output_img: The output image with drawn detections.
        """
        self.img = img

        # Get the model inputs
        model_inputs = self.session.get_inputs()

        # Store the input shape for later use; ONNX inputs are NCHW
        # (batch, channels, height, width)
        input_shape = model_inputs[0].shape
        self.input_height = input_shape[2]
        self.input_width = input_shape[3]

        # Preprocess the image data
        img_data = self.preprocess()

        # Run inference using the preprocessed image data
        outputs = self.session.run(None, {model_inputs[0].name: img_data})

        # Post-process the outputs to obtain the annotated image and detections
        return self.postprocess(self.img, outputs)

def main():
    time_start = time.time()

    # Create an instance of the YOLOv8 class with the specified arguments
    yolo = YOLOv8(
        onnx_model="model/best.onnx",
        confidence_thres=0.5,
        iou_thres=0,
        classes=['line', 'screw_3', 'block_1', 'block_2', 'screw_1', 'screw_2',
                 'wang', 'solder_joint_1', 'screw_1_ng', 'screw_2_ng'],
    )

    # Perform object detection and obtain the output image
    img_path = "/home/wai/hik/code/box-end/app/history/label_images_b1/images/image-200002-2-2024-01-06-14-55-51-705560.jpg"
    img = cv2.imread(img_path)
    output_image, detections = yolo.detect(img)

    print('time cos detect:', time.time()-time_start, 's')

    # # Display the output image in a window
    # cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    # cv2.imshow("Output", output_image)

    # # Wait for a key press to exit
    # cv2.waitKey(0)
if __name__ == "__main__":
    main()
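
One caveat worth noting: this wrapper resizes the input directly, which distorts the aspect ratio on non-square images, whereas the Ultralytics pipeline letterboxes (resize preserving aspect ratio, then pad). If that matters for your data, a minimal letterbox sketch follows; the helper name, pad color, and return values are our own choices, not part of the original code:

import cv2

def letterbox(img, new_shape=(640, 640), pad_color=(114, 114, 114)):
    """Resize preserving aspect ratio, then pad to new_shape (hypothetical helper)."""
    h, w = img.shape[:2]
    r = min(new_shape[0] / h, new_shape[1] / w)  # scale factor
    new_h, new_w = int(round(h * r)), int(round(w * r))
    resized = cv2.resize(img, (new_w, new_h))
    # Split the padding between both sides
    dh, dw = new_shape[0] - new_h, new_shape[1] - new_w
    top, bottom = dh // 2, dh - dh // 2
    left, right = dw // 2, dw - dw // 2
    padded = cv2.copyMakeBorder(resized, top, bottom, left, right,
                                cv2.BORDER_CONSTANT, value=pad_color)
    # Return the scale and offsets so boxes can be mapped back to the original image
    return padded, r, (left, top)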

YOLOv5

Convert to ONNX

python export.py --weights yolov5s.pt --img 640 --batch 1
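
Depending on your yolov5 version, you may also need to pass --include onnx to export.py. Once exported, it is worth sanity-checking the model's input and output shapes with onnxruntime before wrapping it; a quick sketch (the model path is a placeholder):

import onnxruntime as ort

# Inspect the exported model's inputs and outputs (path is a placeholder)
session = ort.InferenceSession("best.onnx", providers=["CPUExecutionProvider"])
for node in session.get_inputs():
    print("input:", node.name, node.shape)    # e.g. images [1, 3, 640, 640]
for node in session.get_outputs():
    print("output:", node.name, node.shape)   # e.g. output0 [1, 25200, 85]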

Wrapper

class YOLOV5():
    def __init__(self, onnxpath, model_conf_thres, model_iou_thres, classes):
        # Create an inference session; recent onnxruntime-gpu builds require
        # the providers argument to be set explicitly
        self.onnx_session = ort.InferenceSession(
            onnxpath, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        self.input_name = self.get_input_name()
        self.output_name = self.get_output_name()
        print("input_name:", self.input_name, "output_name:", self.output_name)
        self.model_conf_thres = model_conf_thres
        self.model_iou_thres = model_iou_thres
        self.classes = classes

        # Generate a fixed color palette for the classes, converted to BGR for OpenCV
        hexs = ('2f4f4f', '191970', '006400', '00ced1', 'ffa500', 'ffff00', '00ff00',
                '00fa9a', '0000ff', 'da70d6', 'd8bfd8', 'ff00ff', '1e90ff', 'f0e68c')
        palette = [self.hex2rgb(f'#{c}') for c in hexs]
        self.color_palette = [[b, g, r] for (r, g, b) in palette]

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))

    #-------------------------------------------------------
    #   Get the names of the model's inputs and outputs
    #-------------------------------------------------------
    def get_input_name(self):
        input_name=[]
        for node in self.onnx_session.get_inputs():
            input_name.append(node.name)
        return input_name
    def get_output_name(self):
        output_name=[]
        for node in self.onnx_session.get_outputs():
            output_name.append(node.name)
        return output_name
    #-------------------------------------------------------
    #   Build the input feed for inference
    #-------------------------------------------------------
    def get_input_feed(self, img_tensor):
        input_feed = {}
        for name in self.input_name:
            input_feed[name] = img_tensor
        return input_feed

    #-------------------------------------------------------
    #   1. Read the image with cv2 and resize it
    #   2. Convert BGR to RGB and HWC to CHW
    #   3. Normalize the pixel values
    #   4. Add a batch dimension
    #   5. Run inference with onnx_session
    #-------------------------------------------------------
    def detect(self, or_img):
        self.img_height, self.img_width = or_img.shape[0:2]
        # Get the model inputs
        model_inputs = self.onnx_session.get_inputs()
        # Store the input shape for later use; ONNX inputs are NCHW
        # (batch, channels, height, width)
        input_shape = model_inputs[0].shape
        self.input_height = input_shape[2]
        self.input_width = input_shape[3]
        resized = cv2.resize(or_img, (self.input_width, self.input_height))

        img = resized[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = img.astype(dtype=np.float32)
        img /= 255.0
        img = np.expand_dims(img, axis=0)

        input_feed = self.get_input_feed(img)
        pred = self.onnx_session.run(None, input_feed)[0]

        pred = self.filter_box(pred, self.model_conf_thres, self.model_iou_thres)
        # Draw on the original image; draw() scales the boxes back to its size
        detections = self.draw(or_img, pred)
        return or_img, detections

    # dets: array of shape [N, 6]; each row is x1, y1, x2, y2, score, class
    # thresh: IoU threshold
    def nms(self, dets, thresh):
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        #-------------------------------------------------------
        #   Compute the area of each box
        #   Sort the boxes by confidence in descending order
        #-------------------------------------------------------
        areas = (y2 - y1 + 1) * (x2 - x1 + 1)
        scores = dets[:, 4]
        keep = []
        index = scores.argsort()[::-1] 

        while index.size > 0:
            i = index[0]
            keep.append(i)
            #-------------------------------------------------------
            #   Compute the intersection of the current box with the
            #   remaining boxes (zero when they do not overlap)
            #-------------------------------------------------------
            x11 = np.maximum(x1[i], x1[index[1:]]) 
            y11 = np.maximum(y1[i], y1[index[1:]])
            x22 = np.minimum(x2[i], x2[index[1:]])
            y22 = np.minimum(y2[i], y2[index[1:]])

            w = np.maximum(0, x22 - x11 + 1)                              
            h = np.maximum(0, y22 - y11 + 1) 

            overlaps = w * h
            #-------------------------------------------------------
            #   Compute the IoU between this box and the others; boxes
            #   with IoU above thresh are duplicates and are dropped,
            #   while boxes with IoU at or below thresh are kept
            #-------------------------------------------------------
            ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
            idx = np.where(ious <= thresh)[0]
            index = index[idx + 1]
        return keep

    def xywh2xyxy(self,x):
        # [x, y, w, h] to [x1, y1, x2, y2]
        y = np.copy(x)
        y[:, 0] = x[:, 0] - x[:, 2] / 2
        y[:, 1] = x[:, 1] - x[:, 3] / 2
        y[:, 2] = x[:, 0] + x[:, 2] / 2
        y[:, 3] = x[:, 1] + x[:, 3] / 2
        return y

    def filter_box(self, org_box, conf_thres, iou_thres):  # filter out useless boxes
        #-------------------------------------------------------
        #   Remove singleton dimensions
        #   Drop boxes whose objectness score is below conf_thres
        #-------------------------------------------------------
        org_box = np.squeeze(org_box)

        conf = org_box[..., 4] > conf_thres
        box = org_box[conf]
        #-------------------------------------------------------
        #   Get the highest-scoring class for each box via argmax
        #-------------------------------------------------------
        cls_cinf = box[..., 5:]
        cls = []
        for i in range(len(cls_cinf)):
            cls.append(int(np.argmax(cls_cinf[i])))
        all_cls = list(set(cls))     
        #-------------------------------------------------------
        #   Filter each class separately:
        #   1. Replace the 6th column with the class index
        #   2. Convert the coordinates with xywh2xyxy
        #   3. Run non-maximum suppression to get the kept indices
        #   4. Use those indices to collect the surviving boxes
        #-------------------------------------------------------
        output = []
        for i in range(len(all_cls)):
            curr_cls = all_cls[i]
            curr_cls_box = []
            curr_out_box = []
            for j in range(len(cls)):
                if cls[j] == curr_cls:
                    box[j][5] = curr_cls
                    curr_cls_box.append(box[j][:6])
            curr_cls_box = np.array(curr_cls_box)
            curr_cls_box = self.xywh2xyxy(curr_cls_box)
            curr_out_box = self.nms(curr_cls_box,iou_thres)
            for k in curr_out_box:
                output.append(curr_cls_box[k])
        output = np.array(output)
        return output

    def draw(self, image, box_data):
        #-------------------------------------------------------
        #   Round the coordinates to integers for drawing
        #-------------------------------------------------------
        boxes = box_data[..., :4].astype(np.int32)
        scores = box_data[..., 4]
        classes = box_data[..., 5].astype(np.int32)

        # Scaling factors from model-input coordinates back to the original image
        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height

        detections = []
        for box, score, cl in zip(boxes, scores, classes):
            x1, y1, x2, y2 = box
            x1 = int(x1 * x_factor)
            y1 = int(y1 * y_factor)
            x2 = int(x2 * x_factor)
            y2 = int(y2 * y_factor)

            x1y1wh = [x1, y1, x2 - x1, y2 - y1]
            detections.append({'class': self.classes[cl], 'conf': float(score),
                               'position': x1y1wh, 'index': int(cl)})
            self.draw_detections(image, x1y1wh, score, cl)
        return detections

    def draw_detections(self, img, box, score, class_id):
        """
        Draws bounding boxes and labels on the input image based on the detected objects.

        Args:
            img: The input image to draw detections on.
            box: Detected bounding box.
            score: Corresponding detection score.
            class_id: Class ID for the detected object.

        Returns:
            None
        """

        # Extract the coordinates of the bounding box
        x1, y1, w, h = box

        # Retrieve the color for the class ID
        color = self.color_palette[class_id]

        # Draw the bounding box on the image
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

        # Create the label text with class name and score
        label = f"{self.classes[class_id]}: {score:.2f}"

        # Calculate the dimensions of the label text
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Calculate the position of the label text
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        # Draw a filled rectangle as the background for the label text
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )

        # Draw the label text on the image
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

def yolov5_main():
    img_path = "/workspace/data/car_door/org/car_door_component/images/image-200002-2-2024-01-04-20-44-12-682811.jpg"
    img = cv2.imread(img_path)

    yolo = YOLOV5(onnxpath="/workspace/doc/pytorch/yolov5/runs/train/exp3/weights/best.onnx",
        model_conf_thres=0.5,
        model_iou_thres=0,
        classes=['line','screw_3','block_1','block_2','screw_1','screw_2','wang','solder_joint_1','screw_1_ng','screw_2_ng'])

    img, detections = yolo.detect(img)
    for detection in detections:
        print(detection)
    cv2.imwrite('yolov5-onnx-result.jpg', img)

    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", img)
    cv2.waitKey(0)

if __name__ == "__main__":
    yolov5_main()
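
Because nms() never touches self, it can also be called unbound for a quick sanity check of the IoU logic; a toy example with made-up boxes:

import numpy as np

# Toy check of the IoU-based NMS above (hypothetical demo data):
# each row is x1, y1, x2, y2, score, class
dets = np.array([
    [100, 100, 200, 200, 0.9, 0],  # kept: highest score
    [105, 105, 205, 205, 0.8, 0],  # suppressed: IoU with the first box is ~0.82
    [300, 300, 400, 400, 0.7, 0],  # kept: no overlap with the others
])
print(YOLOV5.nms(None, dets, thresh=0.45))  # expect indices 0 and 2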