YOLOV5检测代码detect.py注释与解析

YOLOv5代码注释版更新啦,注释的是最近的2021.07.14的版本,且注释更全
github: https://github.com/Laughing-q/yolov5_annotations

YOLOV5检测代码detect.py注释与解析

本文主要对ultralytics\yolov5-v2.0版本的测试代码detect.py的解析,现在v5已经更新了-v3.0版本, 但该代码部分基本上不会有很大的改动,故以下注释与解析都是适用的;当然如果有大改动,笔者也会更新注释。
yolov5其他代码解析

检测参数以及main函数解析

if __name__ == '__main__':
    """
    weights:训练的权重
    source:测试数据,可以是图片/视频路径,也可以是'0'(电脑自带摄像头),也可以是rtsp等视频流
    output:网络预测之后的图片/视频的保存路径
    img-size:网络输入图片大小
    conf-thres:置信度阈值
    iou-thres:做nms的iou阈值
    device:设置设备
    view-img:是否展示预测之后的图片/视频,默认False
    save-txt:是否将预测的框坐标以txt文件形式保存,默认False
    classes:设置只保留某一部分类别,形如0或者0 2 3
    agnostic-nms:进行nms是否也去除不同类别之间的框,默认False
    augment:推理的时候进行多尺度,翻转等操作(TTA)推理
    update:如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='', help='model.pt path(s)')
    parser.add_argument('--source', type=str, default='inference/images', help='source')  # file/folder, 0 for webcam
    parser.add_argument('--output', type=str, default='inference/output', help='output folder')  # output folder
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.65, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.5, help='IOU threshold for NMS')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--augment', action='store_true', help='augmented inference')
    parser.add_argument('--update', action='store_true', help='update all models')
    opt = parser.parse_args()
    print(opt)

    with torch.no_grad():
        if opt.update:  # update all models (to fix SourceChangeWarning)
            for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt']:
                detect()
                # 去除pt文件中的优化器等信息
                strip_optimizer(opt.weights)
        else:
            detect()

detect函数解析

import argparse
import os
import platform
import shutil
import time
from pathlib import Path

import cv2
import torch
import torch.backends.cudnn as cudnn
from numpy import random

from models.experimental import attempt_load
from utils.datasets import LoadStreams, LoadImages
from utils.general import (
    check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, plot_one_box, strip_optimizer)
from utils.torch_utils import select_device, load_classifier, time_synchronized


def detect(save_img=False):
    # 获取输出文件夹,输入源,权重,参数等参数
    out, source, weights, view_img, save_txt, imgsz = \
        opt.output, opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
    webcam = source == '0' or source.startswith('rtsp') or source.startswith('http') or source.endswith('.txt')

    # Initialize
    # 获取设备
    device = select_device(opt.device)
    # 移除之前的输出文件夹
    if os.path.exists(out):
        shutil.rmtree(out)  # delete output folder
    os.makedirs(out)  # make new output folder
    # 如果设备为gpu,使用Float16
    half = device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
    # 加载Float32模型,确保用户设定的输入图片分辨率能整除32(如不能则调整为能整除并返回)
    model = attempt_load(weights, map_location=device)  # load FP32 model
    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
    # 设置Float16
    if half:
        model.half()  # to FP16

    # Second-stage classifier
    # 设置第二次分类,默认不使用
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])  # load weights
        modelc.to(device).eval()

    # Set Dataloader
    # 通过不同的输入源来设置不同的数据加载方式
    vid_path, vid_writer = None, None
    if webcam:
        view_img = True
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz)
    else:
        save_img = True
        # 如果检测视频的时候想显示出来,可以在这里加一行view_img = True
        view_img = True
        dataset = LoadImages(source, img_size=imgsz)

    # Get names and colors
    # 获取类别名字
    names = model.module.names if hasattr(model, 'module') else model.names
    # 设置画框的颜色
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Run inference
    t0 = time.time()
    # 进行一次前向推理,测试程序是否正常
    img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
    _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
    """
    path 图片/视频路径
    img 进行resize+pad之后的图片
    img0 原size图片
    cap 当读取图片时为None,读取视频时为视频源
    """
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        # 图片也设置为Float16
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        # 没有batch_size的话则在最前面添加一个轴
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = time_synchronized()
        # print("preprocess_image:", t1 - t0)
        # t1 = time.time()
        """
        前向传播 返回pred的shape是(1, num_boxes, 5+num_class)
        h,w为传入网络图片的长和宽,注意dataset在检测时使用了矩形推理,所以这里h不一定等于w
        num_boxes = h/32 * w/32 + h/16 * w/16 + h/8 * w/8
        pred[..., 0:4]为预测框坐标
        预测框坐标为xywh(中心点+宽长)格式
        pred[..., 4]为objectness置信度
        pred[..., 5:-1]为分类结果
        """
        pred = model(img, augment=opt.augment)[0]
        t1_ = time_synchronized()
        print('inference:', t1_ - t1)

        # Apply NMS
        # 进行NMS
        """
        pred:前向传播的输出
        conf_thres:置信度阈值
        iou_thres:iou阈值
        classes:是否只保留特定的类别
        agnostic:进行nms是否也去除不同类别之间的框
        经过nms之后,预测框格式:xywh-->xyxy(左上角右下角)
        pred是一个列表list[torch.tensor],长度为batch_size
        每一个torch.tensor的shape为(num_boxes, 6),内容为box+conf+cls
        """
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()
        # t2 = time.time()

        # Apply Classifier
        # 添加二次分类,默认不使用
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)

        # Process detections
        # 对每一张图片作处理
        for i, det in enumerate(pred):  # detections per image
            # 如果输入源是webcam,则batch_size不为1,取出dataset中的一张图片
            if webcam:  # batch_size >= 1
                p, s, im0 = path[i], '%g: ' % i, im0s[i].copy()
            else:
                p, s, im0 = path, '', im0s
            # 设置保存图片/视频的路径
            save_path = str(Path(out) / Path(p).name)
            # 设置保存框坐标txt文件的路径
            txt_path = str(Path(out) / Path(p).stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
            # 设置打印信息(图片长宽)
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                # 调整预测框的坐标:基于resize+pad的图片的坐标-->基于原size图片的坐标
                # 此时坐标格式为xyxy

                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                # 打印检测到的类别数量
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                # Write results
                # 保存预测结果
                for *xyxy, conf, cls in det:
                    if save_txt:  # Write to file
                        # 将xyxy(左上角+右下角)格式转为xywh(中心点+宽长)格式,并除上w,h做归一化,转化为列表再保存
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * 5 + '\n') % (cls, *xywh))  # label format
                    # 在原图上画框
                    if save_img or view_img:  # Add bbox to image
                        label = '%s %.2f' % (names[int(cls)], conf)
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)

            # Print time (inference + NMS)
            # 打印前向传播+nms时间
            print('%sDone. (%.3fs)' % (s, t2 - t1))

            # Stream results
            # 如果设置展示,则show图片/视频
            if view_img:
                cv2.imshow(p, im0)
                if cv2.waitKey(1) == ord('q'):  # q to quit
                    raise StopIteration

            # Save results (image with detections)
            # 设置保存图片/视频
            if save_img:
                if dataset.mode == 'images':
                    cv2.imwrite(save_path, im0)
                else:
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer

                        fourcc = 'mp4v'  # output video codec
                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
                    vid_writer.write(im0)

    if save_txt or save_img:
        print('Results saved to %s' % Path(out))
        # 打开保存图片和txt的路径(好像只适用于MacOS系统)
        if platform == 'darwin' and not opt.update:  # MacOS
            os.system('open ' + save_path)
    # 打印总时间
    print('Done. (%.3fs)' % (time.time() - t0))
  • 178
    点赞
  • 917
    收藏
    觉得还不错? 一键收藏
  • 416
    评论
### 回答1: YOLOv5 中的 detect.py 文件是用来进行目标检测的主要脚本。它包含了对输入图片/视频进行预处理、模型预测、后处理等一系列操作。其中包含了许多函数,可以帮助我们更好的理解 YOLOv5 的工作原理和实现。 ### 回答2: Yolov5detect.py是YOLO中最为重要的文件之一,是实现目标检测的主要文件。这个文件的代码详解包括以下几个方面。 1.导入必要模块和包: detect.py首先要导入必要模块和包,例如PyTorch中的一些工具包、一些模型(如yolov5)、数据增强、摄像头、命令行参数等等。这个步骤是整个代码的必要内容,以保证下面的代码可以正常运行。 2.加载模型并设置设备: 在detect.py文件中,我们需要通过调用指定的模型(如yolov5s、yolov5m、yolov5l和yolov5x)以及相关的预训练权重来进行目标检测。在完成模型加载后,我们需要根据运行环境设置设备,例如,如果有可用GPU,我们可以将模型放到GPU中来进行运算。 3.载入图片或视频: 在进行目标检测时,我们需要载入待处理的图片或视频文件。通过调用OpenCV的相关功能,我们可以从本地文件或网络直播摄像机中读取视频,而从本地文件夹中读取图片。 4.预处理: 预处理是在将图片或视频传输到模型中进行处理之前进行的。在yolov5 detect.py文件中,主要进行以下预处理: (1)调整大小:将图片或视频帧调整至模型所要求的大小。 (2)转化色彩空间:将彩色图片转化为灰度图片或者RGB色空间。 (3)标准化像素值:调整图片或视频帧的像素值范围。 (4)转置和转换格式:对于输入数据,需要将其转置并以适当的格式进行存储。 5.执行推理(inference): 在推理过程中,将预处理后的数据输入到模型中,得到模型的输出(包括检测框、类别、置信度等信息)。这里是整个代码的核心部分,包括前向传播的计算和预测输出的后处理过程。其中,NMS(non-max suppression)是非常关键的一步,因为它能有效减少多余的检测框,精简输出结果。 6.后处理: 预测结果需要进行一些后处理,包括: (1)将检测框转换为像素坐标。 (2)根据置信度和IoU(Intersection over Union)过滤检测框。 (3)在图片或帧上绘制检测框、标签和置信度等信息。 (4)最后,将处理后的图片或视频帧输出到指定位置。 综上所述,yolov5 detect.py文件是实现目标检测的核心文件,通过对文件每一部分的详解,可以更好地理解代码的含义和作用。 ### 回答3: YOLOv5是目前最优秀的目标检测网络之一。在它的代码中,detect.py文件是用来实现检测过程的。下面我们来详细分析一下该文件的代码。 首先,我们需要导入一些必要的库,这些库包括以及它们所提供的模块,如torch、models、utils、general等。然后,我们需要加载一些模型配置文件和权重文件,它们通常是在训练过程中生成的。我们可以从命令行参数中读取这些文件的路径和一些其他的参数信息,比如输入图片的分辨率、置信度阈值和NMS的参数等等。 然后,我们要加载模型并设置为评估模式。这里加载模型的方式是通过配置文件中指定的模型类型和权重文件的路径来进行加载。在模型加载完成后,我们要为检测结果生成一个输出文件的路径和名称。在检测结果输出文件中,每一行的格式是“image_path confidence x_min y_min x_max y_max label”。 接下来,我们要遍历输入图片的路径集合,对于每张输入图片,我们要先对其进行预处理。这个预处理过程包括将图片转换为模型需要的数据格式、将数据放入GPU中进行推理等。然后,我们要对图像进行前向传递,并根据置信度阈值和NMS的参数,筛选出置信度较高的目标框物体。最后,我们将结果写入输出文件,以供后续的处理和分析。 总的来说,detect.py文件主要是用于对输入的图片进行目标检测,它将加载预训练模型和配置文件,并将检测结果写入输出文件。它还提供了一些可配置的参数,比如置信度阈值和NMS的参数,这些参数可以帮助我们调整模型的检测效果和性能。整个检测过程需要先对输入图片进行预处理,然后进行前向传递和筛选,并将最终结果写入输出文件。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 416
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值