yolo系列简化detect函数

yolo系列简化detect函数
1、封装类

import os.path
import argparse
import numpy as np
import torch
from numpy import random

from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import non_max_suppression, scale_coords, xyxy2xywh, check_img_size
from utils.plots import plot_one_box
from utils.torch_utils import time_synchronized


class YOLO(object):
    _defaults = {
        # ----------------------------------------------------------------------------#
        # 事先训练完成的权重文件,比如yolov5s.pt,假如使用官方训练好的文件(比如yolov5s),则会自动下载
        # ----------------------------------------------------------------------------#
        "weights": "weights/yolov7.pt",
        # ----------------------------------------------------------------------------#
        # 预测时的放缩后图片大小(因为YOLO算法需要预先放缩图片)
        # ----------------------------------------------------------------------------#
        "imgsz": 640,
        # ----------------------------------------------------------------------------#
        # use FP16 half-precision inference 是否使用半精度推理(节约显存)
        "half": False,
        # ----------------------------------------------------------------------------#
        # 置信度阈值, 高于此值的bounding_box才会被保留
        # ----------------------------------------------------------------------------#
        "conf_thres": 0.25,
        # ----------------------------------------------------------------------------#
        # IOU阈值,高于此值的bounding_box才会被保留
        # ----------------------------------------------------------------------------#
        "iou_thres": 0.45,
        # ----------------------------------------------------------------------------#
        # augmented inference
        # ----------------------------------------------------------------------------#
        "augment": None,
        # ----------------------------------------------------------------------------#
        # 过滤指定类的预测结果
        # ----------------------------------------------------------------------------#
        "classes": None,
        # ----------------------------------------------------------------------------#
        # 如为True,则为class-agnostic. 否则为class-specific
        # ----------------------------------------------------------------------------#
        "agnostic_nms": False,
        # ----------------------------------------------------------------------------#
        # txt的保存路径
        # ----------------------------------------------------------------------------#
        "txt_path": None,
        # ----------------------------------------------------------------------------#
        # 是否保存结果到txt
        # ----------------------------------------------------------------------------#
        "save_txt": False,
        # ----------------------------------------------------------------------------#
        # 是否保存置信度
        # ----------------------------------------------------------------------------#
        "save_conf": False,
        # ----------------------------------------------------------------------------#
        # 是否展示图片
        # ----------------------------------------------------------------------------#
        "view_img": True
        # ----------------------------------------------------------------------------#
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            self._defaults[name] = value
        self.generate()
        self.show_config(**self._defaults)

    def generate(self):
        assert os.path.exists(self.weights), "weights path does not exist..."
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = attempt_load(self.weights, map_location=self.device)  # load FP32 model
        self.model.to(self.device)
        self.model.eval()
        # Get names and colors
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]

    def show_config(self, **kwargs):
        print('Configurations:')
        print('-' * 70)
        print('|%25s | %40s|' % ('keys', 'values'))
        print('-' * 70)
        for key, value in kwargs.items():
            print('|%25s | %40s|' % (str(key), str(value)))
        print('-' * 70)

    def detect(self, img):
        im0 = img
        stride = int(self.model.stride.max())  # model stride
        imgsz = check_img_size(self.imgsz, s=stride)  # check img_size
        img = letterbox(im0, imgsz, stride=stride)[0]
        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB,to 3x416x416
        img = np.ascontiguousarray(img)

        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        pred = self.model(img, augment=self.augment)[0]

        # Apply NMS
        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=self.classes,
                                   agnostic=self.agnostic_nms)
        t2 = time_synchronized()
        for i, det in enumerate(pred):  # detections per image
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                # for c in det[:, -1].unique():
                #     n = (det[:, -1] == c).sum()  # detections per class
                # Write results
                for *xyxy, conf, cls in reversed(det):
                    conf = conf.item()
                    cls = int(cls.item())
                    # x0 = int(xyxy[0].item())
                    # y0 = int(xyxy[1].item())
                    # x1 = int(xyxy[2].item())
                    # y1 = int(xyxy[3].item())
                    if self.save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if self.save_conf else (cls, *xywh)  # label format
                        with open(self.txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
                    if self.view_img:  # Add bbox to image
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=self.colors[int(cls)], line_thickness=3)
        return im0

if __name__ == '__main__':
    yolo = YOLO()

2、调用

    if mode == "predict":
        while True:
            img = input('Input image filename:')
            try:
                image = cv2.imread(img)
            except:
                print('Open Error! Try again!')
                continue
            else:
                with torch.no_grad():
                    r_image = yolo.detect(image)
                    cv2.imshow('0', r_image)
                    cv2.waitKey(0)

    elif mode == "video":
        capture = cv2.VideoCapture(video_path)
        if video_save_path!="":
            fourcc  = cv2.VideoWriter_fourcc(*'XVID')
            size    = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out     = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        ref, frame = capture.read()
        if not ref:
            raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")

        fps = 0.0
        while(True):
            t1 = time.time()
            # 读取某一帧
            ref, frame = capture.read()
            if not ref:
                break
            # 格式转变,BGRtoRGB
            # frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
            # 转变成Image
            # frame = Image.fromarray(np.uint8(frame))
            # 进行检测
            frame = yolo.detect(frame)
            # RGBtoBGR满足opencv显示格式
            # frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
            
            fps  = ( fps + (1./(time.time()-t1))) / 2.0
            print("fps= %.2f"%(fps))
            frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            
            cv2.imshow("video",frame)
            c= cv2.waitKey(1) & 0xff 
            if video_save_path!="":
                out.write(frame)

            if c==27:
                capture.release()
                break

        print("Video Detection Done!")
        capture.release()
        if video_save_path!="":
            print("Save processed video to the path :" + video_save_path)
            out.release()
        cv2.destroyAllWindows()
  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值