Yolov5 detect.py单张图片检测代码简化

         Yolov5官方下载的代码里,含有detect.py,可用于使用训练好的权重对图片、视频进行预测,但是detect.py里集成了太多功能,如果仅仅是想使用训练好的权重对一张图片进行预测,这个代码就显得太冗杂了。

        将检测一张图片的代码写成一个类,类里有一个detect_image()函数,返回检测到的目标的坐标。

import torch
import sys
sys.path.append(r"./yolov5-6.1")

from models.common import DetectMultiBackend
from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.torch_utils import select_device, time_sync


class predict:
    def __init__(self):
        self.weights = "./best.pt"  # model.pt path(s)
        self.data = "./yolov5-6.1/Clean_data/my_coco128.yaml"  # dataset.yaml path
        self.imgsz = (640, 640)  # inference size (height, width)
        self.conf_thres = 0.5  # confidence threshold
        self.iou_thres = 0.45  # NMS IOU threshold
        self.max_det = 1000  # maximum detections per image
        self.classes = None # filter by class: --class 0, or --class 0 2 3
        self.agnostic_nms = False  # class-agnostic NMS
        self.augment = False  # augmented inference
        self.visualize = False  # visualize features
        self.half = False
        self.dnn = True  # use OpenCV DNN for ONNX inference

        # Load model
        self.device = select_device('cpu')
        self.model = DetectMultiBackend(self.weights, device=self.device, dnn=self.dnn, data=self.data)
        self.stride, self.names, self.pt, self.jit, self.onnx, self.engine = self.model.stride, self.model.names, self.model.pt, self.model.jit, self.model.onnx, self.model.engine
        self.imgsz = check_img_size(self.imgsz, s=self.stride)  # check image size
        # Half
        self.half = False  # FP16 supported on limited backends with CUDA


    def detect_image(self, image_path):
        det_list = []
        # Dataloader
        dataset = LoadImages(image_path, img_size=self.imgsz, stride=self.stride, auto=self.pt)
        # Run inference
        for path, im, im0s, vid_cap, s in dataset:
            im0 = im0s.copy()
            im = torch.from_numpy(im).to(self.device)
            im = im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim
           # Inference
            pred = self.model(im, augment=self.augment, visualize=self.visualize)
            # NMS
            pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes, self.agnostic_nms, max_det=self.max_det)
            # Process predictions
            for i, det in enumerate(pred):  # per image
                if len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
                # 转换为[[ymin, xmin, ymax, xmax, conf, class_pred],[ymin, xmin, ymax, xmax, conf, class_pred]]
                det_list = [[box[1].item(), box[0].item(),
                             box[3].item(), box[2].item(), box[4].item()
                                , int(box[5].item())] for box in det]
        return det_list

  • 8
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值