Yolov5封装detect.py面向对象

本文介绍了一个基于YOLOv5的实时检测系统,使用Websocket客户端将检测到的对象信息打包发送给服务器,适用于摄像头RTSP流的监控应用。
摘要由CSDN通过智能技术生成

主要目标是适应摄像头rtsp流的检测

如果输入源是普通文件夹或图片,去掉 run 方法中的 while True 循环即可(否则会对同一批文件无限循环地重复检测)。

WebClient(来自 websocket_client 模块)是根据需求创建的 Websocket 客户端,负责将检测到的数据打包后发送给服务器。

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run YOLOv5 inference on a camera stream (or images/videos) and send every
detection to a server through a websocket client.

Usage:
    $ python path/to/detect.py  # settings are read from yolo_configs.yaml and configs.yaml
"""

import argparse
import json
import os
import sys
import time
import moment
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

# Make the directory containing this file importable, then keep ROOT as a
# path relative to the current working directory.
FILE = Path(__file__).resolve()
ROOT = FILE.parent  # YOLOv5 root directory
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, os.getcwd()))  # relative

from models.experimental import attempt_load
from utils.datasets import LoadImages, LoadStreams
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \
    increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \
    strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors
from utils.torch_utils import load_classifier, select_device, time_sync

from mytools import read_yaml_all, base64_encode_img
from message_base import MessageBase
from websocket_client import WebClient


class Detect:
    """Object-oriented wrapper around the YOLOv5 ``detect.py`` inference loop.

    ``__init__`` loads the model and opens the data source once. ``run`` then
    iterates the source, runs inference at most once every
    ``check_time_step`` seconds, crops every detected object and sends it to
    the server through ``client`` as a JSON message.
    """

    def __init__(self, config: dict, client: "WebClient"):
        """Read the settings, load the model and open the data source.

        Args:
            config: Settings dictionary; every option is read with
                ``config.get``. ``source`` is required.
            client: Object providing ``connect()`` and ``send(data)``.

        Raises:
            ValueError: If ``config`` has no ``source`` entry.
        """
        self.config = config
        self.weights = self.config.get("weights")  # weights path
        source = self.config.get("source")  # file/dir/URL/stream index
        if source is None:
            raise ValueError("config['source'] is required")
        # str() so that a numeric camera index loaded from YAML as int works
        # with the string checks below.
        self.source = str(source)
        self.imgsz = self.config.get("imgsz")  # inference size
        self.conf_thres = self.config.get("conf_thres")  # confidence threshold
        self.iou_thres = self.config.get("iou_thres")  # NMS IoU threshold
        self.max_det = self.config.get("max_det")  # maximum detections per image
        self.device = self.config.get("device")  # "cpu" or "0,1,2,3"
        self.view_img = self.config.get("view_img")  # show results
        self.save_txt = self.config.get("save_txt")  # save results to *.txt
        self.save_conf = self.config.get("save_conf")  # save confidences in --save-txt labels
        self.save_crop = self.config.get("save_crop")  # save cropped prediction boxes
        self.nosave = self.config.get("nosave")  # do not save images/videos
        self.classes = self.config.get("classes")  # filter by class: --class 0, or --class 0 2 3
        self.agnostic_nms = self.config.get("agnostic_nms")  # class-agnostic NMS
        self.augment = self.config.get("augment")  # augmented inference
        self.visualize = self.config.get("visualize")  # visualize features
        self.update = self.config.get("update")  # update all models
        self.save_path = self.config.get("save_path")  # directory for saved crops
        self.line_thickness = self.config.get("line_thickness")  # bounding box thickness (pixels)
        self.hide_labels = self.config.get("hide_labels")  # hide labels
        self.hide_conf = self.config.get("hide_conf")  # hide confidences
        self.half = self.config.get("half")  # use FP16 half-precision inference
        self.dnn = self.config.get("dnn")  # use OpenCV DNN for ONNX inference
        self.func_device = self.config.get("func_device")  # name of the device this detector serves
        self.save_img = not self.nosave and not self.source.endswith('.txt')  # save inference images
        # Numeric indices, *.txt stream lists and network URLs are treated as live streams.
        self.webcam = self.source.isnumeric() or self.source.endswith('.txt') or self.source.lower().startswith(
            ('rtsp://', 'rtmp://', 'http://', 'https://'))

        set_logging()
        self.device = select_device(self.device)
        # FP16 only when requested AND running on CUDA (half precision is only supported on CUDA).
        self.half = bool(self.half) and self.device.type != 'cpu'
        self.model = attempt_load(self.weights, map_location=self.device)
        self.stride = int(self.model.stride.max())
        self.imgsz = check_img_size(self.imgsz, s=self.stride)
        self.names = self.model.module.names if hasattr(
            self.model, 'module') else self.model.names
        if self.half:
            # The input is cast to FP16 in inference(); the weights must match.
            self.model.half()

        # Open the data source.
        if self.webcam:
            self.view_img = check_imshow()
            cudnn.benchmark = True  # set True to speed up constant image size inference
            self.dataset = LoadStreams(self.source, img_size=self.imgsz, stride=self.stride, auto=True)
            self.bs = len(self.dataset)  # batch_size
        else:
            self.dataset = LoadImages(self.source, img_size=self.imgsz, stride=self.stride, auto=True)
            self.bs = 1  # batch_size

        self.client = client  # client used to send detections to the server
        self.last_time = moment.now()  # earliest time the next inference may run
        self.check_time_step = 5  # minimum number of seconds between two inferences
        os.makedirs(self.save_path, exist_ok=True)  # also creates missing parent directories

    def inference(self, img):
        """Run the model and non-maximum suppression on one preprocessed input.

        Args:
            img: NumPy array yielded by the dataset loader. A 3-dimensional
                input gets a leading batch dimension added.

        Returns:
            The result of ``non_max_suppression``; ``process`` iterates it
            one entry per image.
        """
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        # Inference only: do not record autograd state for every frame.
        with torch.no_grad():
            pred = self.model(img, augment=self.augment)[0]
            # NMS
            pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                       self.classes, self.agnostic_nms, max_det=self.max_det)
        return pred

    @staticmethod
    def _build_payload(device, label, timestamp, xyxy, crop_b64):
        """Serialize one detection into the JSON message sent to the server.

        Args:
            device: Name of the device the detection belongs to.
            label: Class name of the detected object.
            timestamp: Detection time in whole seconds since the epoch.
            xyxy: Box corners ``(x1, y1, x2, y2)``; each value is truncated to int.
            crop_b64: Encoded crop as returned by ``base64_encode_img``.

        Returns:
            The JSON string.
        """
        x1, y1, x2, y2 = (int(v) for v in xyxy[:4])
        return json.dumps({
            "device": device,
            "value": {
                "label": label,
                "time": timestamp,
                "locate": (x1, y1, x2, y2),
                "crop": crop_b64,
            },
        })

    def _send(self, data):
        """Send ``data``; on failure reconnect once and retry.

        A failure of the reconnect or of the retry propagates to the caller.
        """
        try:
            self.client.send(data)
        except Exception as err:
            print("发送失败:", err)
            self.client.connect()
            self.client.send(data)
            print("重连成功!")

    def process(self, im0s, img, pred, path):
        """Crop every detection and send it to the server.

        Label-file writing and box drawing from the upstream script are not
        performed here.

        Args:
            im0s: Original image(s) from the dataset loader: a sequence
                indexed per stream for stream sources, a single image otherwise.
            img: Preprocessed network input; only its spatial shape is used
                to rescale the boxes.
            pred: Per-image detections returned by ``inference``.
            path: Source path(s) from the dataset loader (currently unused).
        """
        for i, det in enumerate(pred):  # per image
            if not len(det):
                continue
            im0 = im0s[i].copy() if self.webcam else im0s.copy()
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

            for *xyxy, conf, cls in reversed(det):
                label = self.names[int(cls)]
                if not label:  # hook for per-label filtering, e.g. label == "person"
                    continue
                t = int(time.time())
                img_path = f"{self.save_path}/{self.func_device}_{label}_{t}.jpg"
                # presumably also writes the crop to img_path — confirm against save_one_box
                crop = save_one_box(xyxy, im0, img_path, BGR=True)
                data = self._build_payload(self.func_device, label, t, xyxy, base64_encode_img(crop))
                self._send(data)
                print(data)

    def run(self):
        """Connect the client and process the data source forever.

        Inference runs on a frame only when ``check_time_step`` seconds have
        passed since the previous one; other frames are skipped. Errors
        raised while handling a frame are printed and the loop continues.

        NOTE: the outer loop never returns. When the dataset is exhausted
        (file sources) it is iterated again from the start.
        """
        self.client.connect()
        while True:
            for path, img, im0s, vid_cap in self.dataset:
                if not self.last_time < moment.now():
                    continue  # throttled: too early for the next inference
                self.last_time = moment.now().add(seconds=self.check_time_step)
                try:
                    pred = self.inference(img)
                    self.process(im0s, img, pred, path)
                except Exception as err:
                    # Best effort: one bad frame must not stop the monitoring loop.
                    print(err)

            if self.save_txt or self.save_img:
                # save_path is a plain string; wrap it for the path operations.
                save_dir = Path(self.save_path)
                s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if self.save_txt else ''
                print(f"Results saved to {colorstr('bold', self.save_path)}{s}")
            if self.update:
                strip_optimizer(self.weights)  # update model (to fix SourceChangeWarning)


if __name__ == "__main__":
    # Script entry point: create the client, point the YOLO settings at the
    # first configured device, then start detecting.
    message_base = MessageBase()
    wc = WebClient("192.168.6.28", 8000)

    configs = read_yaml_all("yolo_configs.yaml")  # detector settings
    config = read_yaml_all("configs.yaml")  # device list and stream URLs

    device_name = config.get("DEVICE_LIST")[0]  # first configured device
    rtsp_urls = config.get("RTSP_URLS")
    configs["source"] = rtsp_urls.get(device_name)
    configs["func_device"] = device_name
    print(configs)

    Detect(configs, wc).run()
  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值