YOLOv5的detect.py部分代码详解

detcet.py详解

代码结构总结

全局导入部分

  1. 导入安装好的库
  2. 获取当前文件绝对路径
  3. 加载自定义的模块
"""
# 这段代码是一个 Python 脚本,它导入了一些模块和库,并定义了一些变量。
#   import argparse: 这个模块用于解析命令行参数和生成帮助信息。
#   import csv: 这个模块提供了对 CSV 文件的读写支持。
#   import os: 这个模块提供了与操作系统交互的功能,例如文件和目录操作。
#   import platform: 这个模块提供了访问底层平台信息的功能,例如操作系统类型和版本。
#   import sys: 这个模块提供了与 Python 解释器交互的功能,例如访问命令行参数和退出程序。
#   from pathlib import Path: 这个类提供了处理文件路径的功能。
# 这段代码的意思是导入所需的模块和库,以便在后续的代码中使用它们。
import argparse
import csv
import os
import platform
import sys
from pathlib import Path

import torch

# 这段代码执行了以下操作:
#   FILE = Path(__file__).resolve(): 这一行创建了一个 Path 对象 FILE,表示当前脚本的绝对路径,并且通过 resolve() 方法确保它是一个绝对路径。
#   ROOT = FILE.parents[0]: 这一行获取了 FILE 的父目录,即当前脚本所在的目录的父目录,将其赋值给 ROOT 变量。在这个脚本中,ROOT 可能是 YOLOv5 项目的根目录。
#   if str(ROOT) not in sys.path:: 这一行检查 ROOT 是否已经存在于 Python 模块搜索路径 sys.path 中
#   sys.path.append(str(ROOT)): 如果 ROOT 不在 sys.path 中,那么将 ROOT 转换为字符串,并将其添加到 sys.path 中,以便 Python 解释器可以在其中查找模块。
#   ROOT = Path(os.path.relpath(ROOT, Path.cwd())): 这一行使用 os.path.relpath() 函数将 ROOT 转换为相对于当前工作目录的相对路径,并重新赋值给 ROOT 变量。
# 这段代码的目的是将当前脚本所在的 YOLOv5 项目的根目录添加到 Python 模块搜索路径中,并将其转换为相对路径形式。这样做可以确保在后续的代码中可以轻松地导入 YOLOv5 项目中的模块。
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

# 这些都是用户自定义的库,由于上一步已经把路径加载上了,所以现在可以导入,这个顺序不可以调换。具体来说,代码从如下几个文件中导入了部分函数和类:
#   models.common.py: 这个文件定义了一些通用的函数和类,比如图像的处理、非极大值抑制等等。
#   utils.dataloaders.py: 这个文件定义了两个类,Loadlmages和LoadStreams,它们可以加载图像或视频帧,并对它们进行一些预处理,以便进行物体检测或识别。
#   utils.general.py: 这个文件定义了一些常用的工具函数,比如检査文件是否存在、检査图像大小是否符合要求、打印命令行参数等等。
#   ultralytics.utils.plotting.py: 这个文件定义了Annotator类,可以在图像上绘制矩形框和标注信息,utils.torch_utils.py: 这个文件定义了一些与PyTorch有关的工具函数,比如选择设备、同步时间等等
# 通过导入这些模块,可以更方便地进行目标检测的相关任务,并且减少了代码的复杂度和冗余。
from ultralytics.utils.plotting import Annotator, colors, save_one_box

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (
    LOGGER,
    Profile,
    check_file,
    check_img_size,
    check_imshow,
    check_requirements,
    colorstr,
    cv2,
    increment_path,
    non_max_suppression,
    print_args,
    scale_boxes,
    strip_optimizer,
    xyxy2xywh,
)
from utils.torch_utils import select_device, smart_inference_mode

执行main函数部分

  1. 解析命令行参数
  2. 执行main函数
    • 检查环境是否都安好了
    • 执行run函数,并传入 命令行参
def main(opt):
    """Executes YOLOv5 model inference with given options, checking requirements before running the model."""
    # 检查环境是否都安装完毕,排除掉两个不检查
    check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop"))
    # 执行run函数,将命令行参数传进去
    run(**vars(opt))

# 命令使用:python detect.py --weights runs/train/exp9/weights/best.pt --source inference/images --conf-thres=0.9
if __name__ == "__main__":
    # 加载命令行参数
    opt = parse_opt()
    main(opt)

设置opt参数部分

# -weights: 训练的权重路径,可以使用自己训练的权重,也可以使用官网提供的权重。默认官网的权重yolov5s.pt(yolov5n.pt/yolov5s.ptyolov5m.ptyolov5l.ptyolov5x.pt/区别在于网络的宽度和深度以此增加)
# -source:测试数据,,可以是图片/视频路径,也可以是"0'(电脑自带摄像头),也可以是rtsp等视频流,默认data/images
# -data: 配置数据文件路径,包括image/label/classes等信息,训练自己的文件,需要作相应更改,可以不用管
# -imgsz:预测时网络输入图片的尺寸,默认值为[640]
# -conf-thres:置信度阈值,默认为 0.50
# -iou-thres :非极大抑制时的 loU 阈值,默认为 0.45
# -max-det:保留的最大检测框数量,每张图片中检测目标的个数最多为1000类
# -device:使用的设备,可以是 cuda 设备的 ID(例如 0、0,1,2,3)或者是'cpu',默认为'0'
# -view-img:是否展示预测之后的图片/视频,默认False
# -save-txt: 是否将预测的框坐标以txt文件形式保存,默认False,使用--save-txt 在路径runs/detect/exp*/labels/*.txt下生成每张图片预测的txt文件
# -save-conf:是否保存检测结果的置信度到 txt文件,默认为 False
# -save-crop:是否保存裁剪预测框图片,默认为False,使用--save-crop 在runs/detect/exp*/crop/剪切类别文件夹/ 路径下会保存每个接下来的目标
# -nosave:不保存图片、视频,要保存图片,不设置--nosave 在runs/detectexp*/会出现预测的结果
# -classes:仅检测指定类别,默认为 None
# --agnostic-nms:是否使用类别不敏感的非极大抑制(即不考虑类别信息),默认为 False
# -augment:是否使用数据增强进行推理,默认为False
# -visualize:是否可视化特征图,默认为 False
# -update: 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
# -project:结果保存的项目目录路径,默认为'ROOT/runs/detect'
# -name:结果保存的子目录名称,默认为'exp
# -exist-ok: 是否覆盖已有结果,默认为 False
# -line-thickness:画 bounding box 时的线条宽度,默认为 3
# -hide-labels:是否隐藏标签信息,默认为 False
# -hide-conf:是否隐藏置信度信息,默认为 False
# -half:是否使用 FP16 半精度进行推理,默认为 False
# -dnn:是否使用 OpenCV DNN 进行 ONNX 推理,默认为 False
def parse_opt():
    """Parses command-line arguments for YOLOv5 detection, setting inference options and model configurations."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path or triton URL")
    parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)")
    parser.add_argument("--data", type=str, default=ROOT / "data/coco128.yaml", help="(optional) dataset.yaml path")
    parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="inference size h,w")
    parser.add_argument("--conf-thres", type=float, default=0.25, help="confidence threshold")
    parser.add_argument("--iou-thres", type=float, default=0.45, help="NMS IoU threshold")
    parser.add_argument("--max-det", type=int, default=1000, help="maximum detections per image")
    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
    parser.add_argument("--view-img", action="store_true", help="show results")
    parser.add_argument("--save-txt", action="store_true", help="save results to *.txt")
    parser.add_argument("--save-csv", action="store_true", help="save results in CSV format")
    parser.add_argument("--save-conf", action="store_true", help="save confidences in --save-txt labels")
    parser.add_argument("--save-crop", action="store_true", help="save cropped prediction boxes")
    parser.add_argument("--nosave", action="store_true", help="do not save images/videos")
    parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3")
    parser.add_argument("--agnostic-nms", action="store_true", help="class-agnostic NMS")
    parser.add_argument("--augment", action="store_true", help="augmented inference")
    parser.add_argument("--visualize", action="store_true", help="visualize features")
    parser.add_argument("--update", action="store_true", help="update all models")
    parser.add_argument("--project", default=ROOT / "runs/detect", help="save results to project/name")
    parser.add_argument("--name", default="exp", help="save results to project/name")
    parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
    parser.add_argument("--line-thickness", default=3, type=int, help="bounding box thickness (pixels)")
    parser.add_argument("--hide-labels", default=False, action="store_true", help="hide labels")
    parser.add_argument("--hide-conf", default=False, action="store_true", help="hide confidences")
    parser.add_argument("--half", action="store_true", help="use FP16 half-precision inference")
    parser.add_argument("--dnn", action="store_true", help="use OpenCV DNN for ONNX inference")
    parser.add_argument("--vid-stride", type=int, default=1, help="video frame-rate stride")
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    print_args(vars(opt)) # 打印所有参数信息
    return opt

run函数部分

  1. 载入命令行参数
  2. 初始化配置:source、save_img等
  3. 得到结果保存路径
  4. 加载模型
  5. 加载数据
  6. 推理部分
    • 先对模型进行热身(预处理)
    • 使用for循环遍历每一帧或每一张图像
      • 图像预处理:维度、归一化等
      • 前向推理,得到推理的预测框
      • 对预测框框执行非极大值抑制
      • 把所有检测框画到原图中:用for循环遍历每个检测框
        • 在原图上绘制检测框
        • 判断要在窗口显示吗?
        • 判断要保存图像吗?
      • 日志打印每张图像所用时间
  7. 在终端打印出运行结果
# run函数,接受opt参数。但是如果出现问题,这里也有参数默认值被设置写死了,都是命令行参数的默认值
'''=========================1.载入参数==========================='''
@smart_inference_mode()
def run(
    weights=ROOT / "yolov5s.pt",  # model path or triton URL
    source=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam)
    data=ROOT / "data/coco128.yaml",  # dataset.yaml path
    imgsz=(640, 640),  # inference size (height, width)
    conf_thres=0.25,  # confidence threshold
    iou_thres=0.45,  # NMS IOU threshold
    max_det=1000,  # maximum detections per image
    device="",  # cuda device, i.e. 0 or 0,1,2,3 or cpu
    view_img=False,  # show results
    save_txt=False,  # save results to *.txt
    save_csv=False,  # save results in CSV format
    save_conf=False,  # save confidences in --save-txt labels
    save_crop=False,  # save cropped prediction boxes
    nosave=False,  # do not save images/videos
    classes=None,  # filter by class: --class 0, or --class 0 2 3
    agnostic_nms=False,  # class-agnostic NMS
    augment=False,  # augmented inference
    visualize=False,  # visualize features
    update=False,  # update all models
    project=ROOT / "runs/detect",  # save results to project/name
    name="exp",  # save results to project/name
    exist_ok=False,  # existing project/name ok, do not increment
    line_thickness=3,  # bounding box thickness (pixels)
    hide_labels=False,  # hide labels
    hide_conf=False,  # hide confidences
    half=False,  # use FP16 half-precision inference
    dnn=False,  # use OpenCV DNN for ONNX inference
    vid_stride=1,  # video frame-rate stride
):
    '''=========================2.初始化配置==========================='''
    # 这段代码主要用于处理输入来源。定义了一些布尔值区分输入是图片、视频、网络流还是摄像头。
    # 首先将source转换为字符串类型,然后判断是否需要保存输出结果。如果nosave和source的后缀不是.txt,则会保存输出结果。
    # 接着根据source的类型,确定输入数据的类型:
    #   如果source的后缀是图像或视频格式之一,那么将is file设置为True;
    #   如果sourcel以rtsp://、rtmp:/1、http://、https:/开头,那么将is url设置为True;
    #   如果source是数字或以.txt结尾或是一个URL,那么将webcam设置为True;
    #   如果source既是文件又是URL,那么会调用check file函数下载文件。
    source = str(source)
    save_img = not nosave and not source.endswith(".txt")  # save inference images
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(("rtsp://", "rtmp://", "http://", "https://"))
    webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
    screenshot = source.lower().startswith("screen")
    if is_url and is_file:
        # 这个函数的作用是根据输入的文件名或路径进行文件查找和下载,并返回文件的路径。
        source = check_file(source)  # download

    '''=========================3.保存结果==========================='''
    # Directories
    # 这个函数的主要作用是生成一个新的文件或目录路径,如果已存在,则在其基础上增加一个数字后缀,以避免重复。
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    # 根据上面生成的路径创建文件夹
    (save_dir / "labels" if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    '''=========================4.加载模型==========================='''
    device = select_device(device)
    # DetectMultiBackend定义在models.common模块中,是我们要加载的网络模型,其中weights参数就是输入时指定的权重文件(比如yolov5s.pt)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)# 将其赋值给model
    '''
            stride:推理时所用到的步长,默认为32, 大步长适合于大目标,小步长适合于小目标
            names:保存推理结果名的列表,比如默认模型的值是['person', 'bicycle', 'car', ...] 
            pt: 加载的是否是pytorch模型(也就是pt格式的文件)
    '''
    stride, names, pt = model.stride, model.names, model.pt
    # 确保输入图片的尺寸imgsz能整除stride=32 如果不能则调整为能被整除并返回
    imgsz = check_img_size(imgsz, s=stride)  # check image size


    # Dataloader
    '''=========================5.加载数据==========================='''
    bs = 1  # batch_size
    # 通过不同的输入源来设置不同的数据加载方式
    if webcam:# 使用摄像头作为输入
        view_img = check_imshow(warn=True)#该函数的目的是在特定环境中检查图像显示功能是否可用,如果不可用,则根据需要发出警告。
        '''
                 source:输入数据源;image_size 图片识别前被放缩的大小;stride:识别时的步长, 
                 auto的作用可以看utils.augmentations.letterbox方法,它决定了是否需要将图片填充为正方形,如果auto=True则不需要
        '''
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)# 加载输入数据流
        bs = len(dataset)
    elif screenshot:
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:# 直接从source文件夹下读取图片
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    vid_path, vid_writer = [None] * bs, [None] * bs


    '''=========================6.推理部分==========================='''
    # Run inference
    '''推理部分:推理部分是整个算法的核心部分,通过for循环对加载的数据进行遍历,一帧一帧的推理,进行非极大值抑制,绘制bounding box、预测类别。'''
    # warmup是 model也就是DetectMultiBackend这个类中的一个函数,对模型进行预处理以加速后续的推理过程
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
    # 遍历数据集:使用for循环遍历数据集中的每一帧图像。数据集中的每个元素包括路径(path)、图像数据(im)、原始图像(im0s)、视频捕获对象(vid_cap)和尺度(s)。
    for path, im, im0s, vid_cap, s in dataset:
        with dt[0]:# 这是一个上下文管理器,用于记录预处理的时间。
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:# 如果图像张量的维度为3,则在最前面添加一个维度,用于表示batch大小。
                im = im[None]  # expand for batch dim
            if model.xml and im.shape[0] > 1:
                ims = torch.chunk(im, im.shape[0], 0)

        # Inference
        with dt[1]:# 这是一个上下文管理器,用于记录推理和可视化的时间。
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            if model.xml and im.shape[0] > 1:
                pred = None
                for image in ims:
                    if pred is None:
                        pred = model(image, augment=augment, visualize=visualize).unsqueeze(0)
                    else:
                        pred = torch.cat((pred, model(image, augment=augment, visualize=visualize).unsqueeze(0)), dim=0)
                pred = [pred, None]
            else:
                # 前向推理
                # pred表示模型预测出来的所有框框
                pred = model(im, augment=augment, visualize=visualize)
        # NMS
        with dt[2]:# 这是一个上下文管理器,用于记录非极大值抑制的时间。
            # 对预测结果进行非极大值抑制,剔除重叠度较高的边界框。
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Define the path for the CSV file
        csv_path = save_dir / "predictions.csv"

        # Create or append to the CSV file
        def write_to_csv(image_name, prediction, confidence):
            """Writes prediction data for an image to a CSV file, appending if the file exists."""
            data = {"Image Name": image_name, "Prediction": prediction, "Confidence": confidence}
            with open(csv_path, mode="a", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=data.keys())
                if not csv_path.is_file():
                    writer.writeheader()
                writer.writerow(data)

        # Process predictions
        # 预测的过程
        # 把所有检测框画到原图中
        '''
            这段代码使用了一个循环来遍历检测结果列表中的每个物体,并对每个物体进行处理
            循环中的变量“i”是一个索引变量,表示当前正在处理第几个物体,而变量"det"则表示当前物体的检测结果。循环体中的第一行代码"seen += 1"用于增加一个计数器,记录已处理的物体数量。
            接下来,根据是否使用网络摄像头来判断处理单张图像还是批量图像。
                如果使用的是网络摄像头,则代码会遍历每个图像并复制一份备份到变量"im0"中,同时将当前图像的路径和计数器记录到变量"p"和"frame"中。最后,将当前处理的物体索引和相关信息记录到字符串变量"s"中。
                如果没有使用网络摄像头,则会直接使用"im0s"变量中的图像,将图像路径和计数器记录到变量"p"和"frame"中。同时,还会检査数据集中是否有"frame"属性,如果有,则将其值记录到变量"frame"中。
        '''
        for i, det in enumerate(pred):  # per image
            seen += 1
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f"{i}: "
            else:
                '''
                        大部分我们一般都是从LoadImages流读取本都文件中的照片或者视频 所以batch_size=1
                           p: 当前图片/视频的绝对路径 如 F:\yolo_v5\yolov5-U\data\images\bus.jpg
                           im0: 原始图片 letterbox + pad 之前的图片
                           frame: 视频流,此次取的是第几张图片
                '''
                p, im0, frame = path, im0s.copy(), getattr(dataset, "frame", 0)


            p = Path(p)  # to Path
            # 图片/视频的保存路径save_path 如 runs\\detect\\exp8\\fire.jpg
            save_path = str(save_dir / p.name)  # im.jpg
            # 设置保存框坐标的txt文件路径,每张图片对应一个框坐标信息
            txt_path = str(save_dir / "labels" / p.stem) + ("" if dataset.mode == "image" else f"_{frame}")  # im.txt
            # 设置输出图片信息。图片shape (w, h)
            s += "%gx%g " % im.shape[2:]  # print string
            # 得到原图的宽和高
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            # 保存截图。如果save_crop的值为true,则将检测到的bounding_box单独保存成一张图片。
            imc = im0.copy() if save_crop else im0  # for save_crop

            # 这段代码绘制框框
            # 得到一个绘图的类,类中预先存储了原图、线条宽度、类名
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                # Rescale boxes from img_size to im0 size,因为原图进行了放缩
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, 5].unique():
                    n = (det[:, 5] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                # 每个元素包含对象的边界框坐标(xyxy)、置信度(conf)和类别(cls)。
                for *xyxy, conf, cls in reversed(det):
                    # 这里将类别 cls 转换为整数,并根据是否隐藏置信度来确定标签 label。
                    # 同时将置信度转换为浮点数,并以两位小数形式表示。
                    c = int(cls)  # integer class
                    label = names[c] if hide_conf else f"{names[c]}"
                    confidence = float(conf)
                    confidence_str = f"{confidence:.2f}"

                    if save_csv:
                        write_to_csv(p.name, label, confidence_str)

                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                        with open(f"{txt_path}.txt", "a") as f:
                            f.write(("%g " * len(line)).rstrip() % line + "\n")
                    # 只要是不是txt且不是不保存(save_img = not nosave and not source.endswith(".txt"))save_img就是满足的
                    if save_img or save_crop or view_img:  # Add bbox to image
                        c = int(cls)  # integer class
                        label = None if hide_labels else (names[c] if hide_conf else f"{names[c]} {conf:.2f}")
                        # 在图像上绘制边界框和标签
                        annotator.box_label(xyxy, label, color=colors(c, True))
                    if save_crop:
                        save_one_box(xyxy, imc, file=save_dir / "crops" / names[c] / f"{p.stem}.jpg", BGR=True)

            '''
                这段代码片段的功能是将检测结果实时流式传输,并且在需要时保存检测结果。
                具体来说:
                    首先,将检测结果图像 im0 通过 annotator.result() 获取。
                    如果 view_img 为真,则显示图像流。这里使用了 OpenCV 的 cv2.imshow() 函数将图像显示在窗口中,并通过 cv2.waitKey(1) 等待1毫秒,以便允许用户对窗口进行交互操作。
                    接着,在满足条件 save_img 为真时,保存检测结果图像:
                        如果数据集的模式是图像,直接使用 OpenCV 的 cv2.imwrite() 函数将图像保存到指定路径 save_path。
                        如果数据集的模式是视频或流,则执行以下操作:
                            检查当前视频的保存路径是否与之前不同,如果是,则需要初始化一个新的视频写入器。
                            释放之前的视频写入器。
                            如果是视频文件,则获取视频的帧率(fps)、宽度和高度。
                            如果是实时视频流,则假设帧率为 30 帧每秒(fps=30),并获取当前帧的宽度和高度。
                            将保存路径的文件后缀强制设为 ".mp4"。
                            使用 OpenCV 的 cv2.VideoWriter() 函数初始化一个视频写入器,指定视频编解码器为 MP4V,帧率为上面获取的帧率,图像大小为上面获取的宽度和高度。
                            将当前帧 im0 写入视频。
                综上所述,这段代码的功能是实时流式传输检测结果,并在需要时保存这些结果。
            '''
            # Stream results
            im0 = annotator.result()
            if view_img: # 是否在窗口展示图像
                if platform.system() == "Linux" and p not in windows:
                    windows.append(p)
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img: # 是否保存图像
                if dataset.mode == "image":
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'
                    if vid_path[i] != save_path:  # new video
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path = str(Path(save_path).with_suffix(".mp4"))  # force *.mp4 suffix on results videos
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
                    vid_writer[i].write(im0)

        # Print time (inference-only)
        # 每一帧图像,每一张图像:在日志中打印推理时间以及检测结果的数量信息。
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

    # Print results
    '''================7.在终端里打印出运行的结果============================'''
    # 所有图像都处理完毕,打印结果信息,并保存结果
    t = tuple(x.t / seen * 1e3 for x in dt)  # speeds per image平均每张图片所耗费时间
    LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ""# 标签保存的路径
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)

整体detect.py代码

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Run YOLOv5 detection inference on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ python detect.py --weights yolov5s.pt --source 0                               # webcam
                                                     img.jpg                         # image
                                                     vid.mp4                         # video
                                                     screen                          # screenshot
                                                     path/                           # directory
                                                     list.txt                        # list of images
                                                     list.streams                    # list of streams
                                                     'path/*.jpg'                    # glob
                                                     'https://youtu.be/LNwODJXcvt4'  # YouTube
                                                     'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream

Usage - formats:
    $ python detect.py --weights yolov5s.pt                 # PyTorch
                                 yolov5s.torchscript        # TorchScript
                                 yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
                                 yolov5s_openvino_model     # OpenVINO
                                 yolov5s.engine             # TensorRT
                                 yolov5s.mlmodel            # CoreML (macOS-only)
                                 yolov5s_saved_model        # TensorFlow SavedModel
                                 yolov5s.pb                 # TensorFlow GraphDef
                                 yolov5s.tflite             # TensorFlow Lite
                                 yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
                                 yolov5s_paddle_model       # PaddlePaddle
"""
# 这段代码是一个 Python 脚本,它导入了一些模块和库,并定义了一些变量。
#   import argparse: 这个模块用于解析命令行参数和生成帮助信息。
#   import csv: 这个模块提供了对 CSV 文件的读写支持。
#   import os: 这个模块提供了与操作系统交互的功能,例如文件和目录操作。
#   import platform: 这个模块提供了访问底层平台信息的功能,例如操作系统类型和版本。
#   import sys: 这个模块提供了与 Python 解释器交互的功能,例如访问命令行参数和退出程序。
#   from pathlib import Path: 这个类提供了处理文件路径的功能。
# 这段代码的意思是导入所需的模块和库,以便在后续的代码中使用它们。
import argparse
import csv
import os
import platform
import sys
from pathlib import Path

import torch

# 这段代码执行了以下操作:
#   FILE = Path(__file__).resolve(): 这一行创建了一个 Path 对象 FILE,表示当前脚本的绝对路径,并且通过 resolve() 方法确保它是一个绝对路径。
#   ROOT = FILE.parents[0]: 这一行获取了 FILE 的父目录,即当前脚本所在的目录的父目录,将其赋值给 ROOT 变量。在这个脚本中,ROOT 可能是 YOLOv5 项目的根目录。
#   if str(ROOT) not in sys.path:: 这一行检查 ROOT 是否已经存在于 Python 模块搜索路径 sys.path 中
#   sys.path.append(str(ROOT)): 如果 ROOT 不在 sys.path 中,那么将 ROOT 转换为字符串,并将其添加到 sys.path 中,以便 Python 解释器可以在其中查找模块。
#   ROOT = Path(os.path.relpath(ROOT, Path.cwd())): 这一行使用 os.path.relpath() 函数将 ROOT 转换为相对于当前工作目录的相对路径,并重新赋值给 ROOT 变量。
# 这段代码的目的是将当前脚本所在的 YOLOv5 项目的根目录添加到 Python 模块搜索路径中,并将其转换为相对路径形式。这样做可以确保在后续的代码中可以轻松地导入 YOLOv5 项目中的模块。
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

# 这些都是用户自定义的库,由于上一步已经把路径加载上了,所以现在可以导入,这个顺序不可以调换。具体来说,代码从如下几个文件中导入了部分函数和类:
#   models.common.py: 这个文件定义了一些通用的函数和类,比如图像的处理、非极大值抑制等等。
#   utils.dataloaders.py: 这个文件定义了两个类,Loadlmages和LoadStreams,它们可以加载图像或视频帧,并对它们进行一些预处理,以便进行物体检测或识别。
#   utils.general.py: 这个文件定义了一些常用的工具函数,比如检査文件是否存在、检査图像大小是否符合要求、打印命令行参数等等。
#   ultralytics.utils.plotting.py: 这个文件定义了Annotator类,可以在图像上绘制矩形框和标注信息,utils.torch_utils.py: 这个文件定义了一些与PyTorch有关的工具函数,比如选择设备、同步时间等等
# 通过导入这些模块,可以更方便地进行目标检测的相关任务,并且减少了代码的复杂度和冗余。
from ultralytics.utils.plotting import Annotator, colors, save_one_box

from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (
    LOGGER,
    Profile,
    check_file,
    check_img_size,
    check_imshow,
    check_requirements,
    colorstr,
    cv2,
    increment_path,
    non_max_suppression,
    print_args,
    scale_boxes,
    strip_optimizer,
    xyxy2xywh,
)
from utils.torch_utils import select_device, smart_inference_mode

# run函数,接受opt参数。但是如果出现问题,这里也有参数默认值被设置写死了,都是命令行参数的默认值
'''=========================1.载入参数==========================='''
@smart_inference_mode()
def run(
    weights=ROOT / "yolov5s.pt",  # model path or triton URL
    source=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam)
    data=ROOT / "data/coco128.yaml",  # dataset.yaml path
    imgsz=(640, 640),  # inference size (height, width)
    conf_thres=0.25,  # confidence threshold
    iou_thres=0.45,  # NMS IOU threshold
    max_det=1000,  # maximum detections per image
    device="",  # cuda device, i.e. 0 or 0,1,2,3 or cpu
    view_img=False,  # show results
    save_txt=False,  # save results to *.txt
    save_csv=False,  # save results in CSV format
    save_conf=False,  # save confidences in --save-txt labels
    save_crop=False,  # save cropped prediction boxes
    nosave=False,  # do not save images/videos
    classes=None,  # filter by class: --class 0, or --class 0 2 3
    agnostic_nms=False,  # class-agnostic NMS
    augment=False,  # augmented inference
    visualize=False,  # visualize features
    update=False,  # update all models
    project=ROOT / "runs/detect",  # save results to project/name
    name="exp",  # save results to project/name
    exist_ok=False,  # existing project/name ok, do not increment
    line_thickness=3,  # bounding box thickness (pixels)
    hide_labels=False,  # hide labels
    hide_conf=False,  # hide confidences
    half=False,  # use FP16 half-precision inference
    dnn=False,  # use OpenCV DNN for ONNX inference
    vid_stride=1,  # video frame-rate stride
):
    '''=========================2.初始化配置==========================='''
    # 这段代码主要用于处理输入来源。定义了一些布尔值区分输入是图片、视频、网络流还是摄像头。
    # 首先将source转换为字符串类型,然后判断是否需要保存输出结果。如果nosave和source的后缀不是.txt,则会保存输出结果。
    # 接着根据source的类型,确定输入数据的类型:
    #   如果source的后缀是图像或视频格式之一,那么将is file设置为True;
    #   如果sourcel以rtsp://、rtmp:/1、http://、https:/开头,那么将is url设置为True;
    #   如果source是数字或以.txt结尾或是一个URL,那么将webcam设置为True;
    #   如果source既是文件又是URL,那么会调用check file函数下载文件。
    source = str(source)
    save_img = not nosave and not source.endswith(".txt")  # save inference images
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
    is_url = source.lower().startswith(("rtsp://", "rtmp://", "http://", "https://"))
    webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
    screenshot = source.lower().startswith("screen")
    if is_url and is_file:
        # 这个函数的作用是根据输入的文件名或路径进行文件查找和下载,并返回文件的路径。
        source = check_file(source)  # download

    '''=========================3.保存结果==========================='''
    # Directories
    # 这个函数的主要作用是生成一个新的文件或目录路径,如果已存在,则在其基础上增加一个数字后缀,以避免重复。
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
    # 根据上面生成的路径创建文件夹
    (save_dir / "labels" if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

    # Load model
    '''=========================4.加载模型==========================='''
    device = select_device(device)
    # DetectMultiBackend定义在models.common模块中,是我们要加载的网络模型,其中weights参数就是输入时指定的权重文件(比如yolov5s.pt)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)# 将其赋值给model
    '''
            stride:推理时所用到的步长,默认为32, 大步长适合于大目标,小步长适合于小目标
            names:保存推理结果名的列表,比如默认模型的值是['person', 'bicycle', 'car', ...] 
            pt: 加载的是否是pytorch模型(也就是pt格式的文件)
    '''
    stride, names, pt = model.stride, model.names, model.pt
    # 确保输入图片的尺寸imgsz能整除stride=32 如果不能则调整为能被整除并返回
    imgsz = check_img_size(imgsz, s=stride)  # check image size


    # Dataloader
    '''=========================5.加载数据==========================='''
    bs = 1  # batch_size
    # 通过不同的输入源来设置不同的数据加载方式
    if webcam:# 使用摄像头作为输入
        view_img = check_imshow(warn=True)#该函数的目的是在特定环境中检查图像显示功能是否可用,如果不可用,则根据需要发出警告。
        '''
                 source:输入数据源;image_size 图片识别前被放缩的大小;stride:识别时的步长, 
                 auto的作用可以看utils.augmentations.letterbox方法,它决定了是否需要将图片填充为正方形,如果auto=True则不需要
        '''
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)# 加载输入数据流
        bs = len(dataset)
    elif screenshot:
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:# 直接从source文件夹下读取图片
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    vid_path, vid_writer = [None] * bs, [None] * bs


    '''=========================6.推理部分==========================='''
    # Run inference
    '''推理部分:推理部分是整个算法的核心部分,通过for循环对加载的数据进行遍历,一帧一帧的推理,进行非极大值抑制,绘制bounding box、预测类别。'''
    # warmup是 model也就是DetectMultiBackend这个类中的一个函数,对模型进行预处理以加速后续的推理过程
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
    # 遍历数据集:使用for循环遍历数据集中的每一帧图像。数据集中的每个元素包括路径(path)、图像数据(im)、原始图像(im0s)、视频捕获对象(vid_cap)和尺度(s)。
    for path, im, im0s, vid_cap, s in dataset:
        with dt[0]:# 这是一个上下文管理器,用于记录预处理的时间。
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:# 如果图像张量的维度为3,则在最前面添加一个维度,用于表示batch大小。
                im = im[None]  # expand for batch dim
            if model.xml and im.shape[0] > 1:
                ims = torch.chunk(im, im.shape[0], 0)

        # Inference
        with dt[1]:# 这是一个上下文管理器,用于记录推理和可视化的时间。
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            if model.xml and im.shape[0] > 1:
                pred = None
                for image in ims:
                    if pred is None:
                        pred = model(image, augment=augment, visualize=visualize).unsqueeze(0)
                    else:
                        pred = torch.cat((pred, model(image, augment=augment, visualize=visualize).unsqueeze(0)), dim=0)
                pred = [pred, None]
            else:
                # 前向推理
                # pred表示模型预测出来的所有框框
                pred = model(im, augment=augment, visualize=visualize)
        # NMS
        with dt[2]:# 这是一个上下文管理器,用于记录非极大值抑制的时间。
            # 对预测结果进行非极大值抑制,剔除重叠度较高的边界框。
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        # Second-stage classifier (optional)
        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)

        # Define the path for the CSV file
        csv_path = save_dir / "predictions.csv"

        # Create or append to the CSV file
        def write_to_csv(image_name, prediction, confidence):
            """Writes prediction data for an image to a CSV file, appending if the file exists."""
            data = {"Image Name": image_name, "Prediction": prediction, "Confidence": confidence}
            with open(csv_path, mode="a", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=data.keys())
                if not csv_path.is_file():
                    writer.writeheader()
                writer.writerow(data)

        # Process predictions
        # 预测的过程
        # 把所有检测框画到原图中
        '''
            这段代码使用了一个循环来遍历检测结果列表中的每个物体,并对每个物体进行处理
            循环中的变量“i”是一个索引变量,表示当前正在处理第几个物体,而变量"det"则表示当前物体的检测结果。循环体中的第一行代码"seen += 1"用于增加一个计数器,记录已处理的物体数量。
            接下来,根据是否使用网络摄像头来判断处理单张图像还是批量图像。
                如果使用的是网络摄像头,则代码会遍历每个图像并复制一份备份到变量"im0"中,同时将当前图像的路径和计数器记录到变量"p"和"frame"中。最后,将当前处理的物体索引和相关信息记录到字符串变量"s"中。
                如果没有使用网络摄像头,则会直接使用"im0s"变量中的图像,将图像路径和计数器记录到变量"p"和"frame"中。同时,还会检査数据集中是否有"frame"属性,如果有,则将其值记录到变量"frame"中。
        '''
        for i, det in enumerate(pred):  # per image
            seen += 1
            if webcam:  # batch_size >= 1
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f"{i}: "
            else:
                '''
                        大部分我们一般都是从LoadImages流读取本都文件中的照片或者视频 所以batch_size=1
                           p: 当前图片/视频的绝对路径 如 F:\yolo_v5\yolov5-U\data\images\bus.jpg
                           im0: 原始图片 letterbox + pad 之前的图片
                           frame: 视频流,此次取的是第几张图片
                '''
                p, im0, frame = path, im0s.copy(), getattr(dataset, "frame", 0)


            p = Path(p)  # to Path
            # 图片/视频的保存路径save_path 如 runs\\detect\\exp8\\fire.jpg
            save_path = str(save_dir / p.name)  # im.jpg
            # 设置保存框坐标的txt文件路径,每张图片对应一个框坐标信息
            txt_path = str(save_dir / "labels" / p.stem) + ("" if dataset.mode == "image" else f"_{frame}")  # im.txt
            # 设置输出图片信息。图片shape (w, h)
            s += "%gx%g " % im.shape[2:]  # print string
            # 得到原图的宽和高
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            # 保存截图。如果save_crop的值为true,则将检测到的bounding_box单独保存成一张图片。
            imc = im0.copy() if save_crop else im0  # for save_crop

            # 这段代码绘制框框
            # 得到一个绘图的类,类中预先存储了原图、线条宽度、类名
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                # Rescale boxes from img_size to im0 size,因为原图进行了放缩
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, 5].unique():
                    n = (det[:, 5] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                # 每个元素包含对象的边界框坐标(xyxy)、置信度(conf)和类别(cls)。
                for *xyxy, conf, cls in reversed(det):
                    # 这里将类别 cls 转换为整数,并根据是否隐藏置信度来确定标签 label。
                    # 同时将置信度转换为浮点数,并以两位小数形式表示。
                    c = int(cls)  # integer class
                    label = names[c] if hide_conf else f"{names[c]}"
                    confidence = float(conf)
                    confidence_str = f"{confidence:.2f}"

                    if save_csv:
                        write_to_csv(p.name, label, confidence_str)

                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                        with open(f"{txt_path}.txt", "a") as f:
                            f.write(("%g " * len(line)).rstrip() % line + "\n")
                    # 只要是不是txt且不是不保存(save_img = not nosave and not source.endswith(".txt"))save_img就是满足的
                    if save_img or save_crop or view_img:  # Add bbox to image
                        c = int(cls)  # integer class
                        label = None if hide_labels else (names[c] if hide_conf else f"{names[c]} {conf:.2f}")
                        # 在图像上绘制边界框和标签
                        annotator.box_label(xyxy, label, color=colors(c, True))
                    if save_crop:
                        save_one_box(xyxy, imc, file=save_dir / "crops" / names[c] / f"{p.stem}.jpg", BGR=True)

            '''
                这段代码片段的功能是将检测结果实时流式传输,并且在需要时保存检测结果。
                具体来说:
                    首先,将检测结果图像 im0 通过 annotator.result() 获取。
                    如果 view_img 为真,则显示图像流。这里使用了 OpenCV 的 cv2.imshow() 函数将图像显示在窗口中,并通过 cv2.waitKey(1) 等待1毫秒,以便允许用户对窗口进行交互操作。
                    接着,在满足条件 save_img 为真时,保存检测结果图像:
                        如果数据集的模式是图像,直接使用 OpenCV 的 cv2.imwrite() 函数将图像保存到指定路径 save_path。
                        如果数据集的模式是视频或流,则执行以下操作:
                            检查当前视频的保存路径是否与之前不同,如果是,则需要初始化一个新的视频写入器。
                            释放之前的视频写入器。
                            如果是视频文件,则获取视频的帧率(fps)、宽度和高度。
                            如果是实时视频流,则假设帧率为 30 帧每秒(fps=30),并获取当前帧的宽度和高度。
                            将保存路径的文件后缀强制设为 ".mp4"。
                            使用 OpenCV 的 cv2.VideoWriter() 函数初始化一个视频写入器,指定视频编解码器为 MP4V,帧率为上面获取的帧率,图像大小为上面获取的宽度和高度。
                            将当前帧 im0 写入视频。
                综上所述,这段代码的功能是实时流式传输检测结果,并在需要时保存这些结果。
            '''
            # Stream results
            im0 = annotator.result()
            if view_img: # 是否在窗口展示图像
                if platform.system() == "Linux" and p not in windows:
                    windows.append(p)
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond

            # Save results (image with detections)
            if save_img: # 是否保存图像
                if dataset.mode == "image":
                    cv2.imwrite(save_path, im0)
                else:  # 'video' or 'stream'
                    if vid_path[i] != save_path:  # new video
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path = str(Path(save_path).with_suffix(".mp4"))  # force *.mp4 suffix on results videos
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
                    vid_writer[i].write(im0)

        # Print time (inference-only)
        # 每一帧图像,每一张图像:在日志中打印推理时间以及检测结果的数量信息。
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

    # Print results
    '''================7.在终端里打印出运行的结果============================'''
    # 所有图像都处理完毕,打印结果信息,并保存结果
    t = tuple(x.t / seen * 1e3 for x in dt)  # speeds per image平均每张图片所耗费时间
    LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ""# 标签保存的路径
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)

# -weights: 训练的权重路径,可以使用自己训练的权重,也可以使用官网提供的权重。默认官网的权重yolov5s.pt(yolov5n.pt/yolov5s.ptyolov5m.ptyolov5l.ptyolov5x.pt/区别在于网络的宽度和深度以此增加)
# -source:测试数据,,可以是图片/视频路径,也可以是"0'(电脑自带摄像头),也可以是rtsp等视频流,默认data/images
# -data: 配置数据文件路径,包括image/label/classes等信息,训练自己的文件,需要作相应更改,可以不用管
# -imgsz:预测时网络输入图片的尺寸,默认值为[640]
# -conf-thres:置信度阈值,默认为 0.50
# -iou-thres :非极大抑制时的 loU 阈值,默认为 0.45
# -max-det:保留的最大检测框数量,每张图片中检测目标的个数最多为1000类
# -device:使用的设备,可以是 cuda 设备的 ID(例如 0、0,1,2,3)或者是'cpu',默认为'0'
# -view-img:是否展示预测之后的图片/视频,默认False
# -save-txt: 是否将预测的框坐标以txt文件形式保存,默认False,使用--save-txt 在路径runs/detect/exp*/labels/*.txt下生成每张图片预测的txt文件
# -save-conf:是否保存检测结果的置信度到 txt文件,默认为 False
# -save-crop:是否保存裁剪预测框图片,默认为False,使用--save-crop 在runs/detect/exp*/crop/剪切类别文件夹/ 路径下会保存每个接下来的目标
# -nosave:不保存图片、视频,要保存图片,不设置--nosave 在runs/detectexp*/会出现预测的结果
# -classes:仅检测指定类别,默认为 None
# --agnostic-nms:是否使用类别不敏感的非极大抑制(即不考虑类别信息),默认为 False
# -augment:是否使用数据增强进行推理,默认为False
# -visualize:是否可视化特征图,默认为 False
# -update: 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
# -project:结果保存的项目目录路径,默认为'ROOT/runs/detect'
# -name:结果保存的子目录名称,默认为'exp
# -exist-ok: 是否覆盖已有结果,默认为 False
# -line-thickness:画 bounding box 时的线条宽度,默认为 3
# -hide-labels:是否隐藏标签信息,默认为 False
# -hide-conf:是否隐藏置信度信息,默认为 False
# -half:是否使用 FP16 半精度进行推理,默认为 False
# -dnn:是否使用 OpenCV DNN 进行 ONNX 推理,默认为 False
def parse_opt():
    """Parses command-line arguments for YOLOv5 detection, setting inference options and model configurations."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path or triton URL")
    parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)")
    parser.add_argument("--data", type=str, default=ROOT / "data/coco128.yaml", help="(optional) dataset.yaml path")
    parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="inference size h,w")
    parser.add_argument("--conf-thres", type=float, default=0.25, help="confidence threshold")
    parser.add_argument("--iou-thres", type=float, default=0.45, help="NMS IoU threshold")
    parser.add_argument("--max-det", type=int, default=1000, help="maximum detections per image")
    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
    parser.add_argument("--view-img", action="store_true", help="show results")
    parser.add_argument("--save-txt", action="store_true", help="save results to *.txt")
    parser.add_argument("--save-csv", action="store_true", help="save results in CSV format")
    parser.add_argument("--save-conf", action="store_true", help="save confidences in --save-txt labels")
    parser.add_argument("--save-crop", action="store_true", help="save cropped prediction boxes")
    parser.add_argument("--nosave", action="store_true", help="do not save images/videos")
    parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3")
    parser.add_argument("--agnostic-nms", action="store_true", help="class-agnostic NMS")
    parser.add_argument("--augment", action="store_true", help="augmented inference")
    parser.add_argument("--visualize", action="store_true", help="visualize features")
    parser.add_argument("--update", action="store_true", help="update all models")
    parser.add_argument("--project", default=ROOT / "runs/detect", help="save results to project/name")
    parser.add_argument("--name", default="exp", help="save results to project/name")
    parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
    parser.add_argument("--line-thickness", default=3, type=int, help="bounding box thickness (pixels)")
    parser.add_argument("--hide-labels", default=False, action="store_true", help="hide labels")
    parser.add_argument("--hide-conf", default=False, action="store_true", help="hide confidences")
    parser.add_argument("--half", action="store_true", help="use FP16 half-precision inference")
    parser.add_argument("--dnn", action="store_true", help="use OpenCV DNN for ONNX inference")
    parser.add_argument("--vid-stride", type=int, default=1, help="video frame-rate stride")
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
    print_args(vars(opt)) # 打印所有参数信息
    return opt


def main(opt):
    """Executes YOLOv5 model inference with given options, checking requirements before running the model."""
    # 检查环境是否都安装完毕,排除掉两个不检查
    check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop"))
    # 执行run函数,将命令行参数传进去
    run(**vars(opt))

# 命令使用:python detect.py --weights runs/train/exp9/weights/best.pt --source inference/images --conf-thres=0.9
if __name__ == "__main__":
    # 加载命令行参数
    opt = parse_opt()
    main(opt)

参考文章

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值