YOLOv5的检测代码(中文注释)

YOLOv5的部署请移步这里http://t.csdnimg.cn/29Jor

# -*- coding: utf-8 -*-
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
运行 YOLOv5 检测推理,适用于图像、视频、目录、globs、YouTube、网络摄像头、流等。

使用方法 - 数据源:
    $ python detect.py --weights yolov5s.pt --source 0                               # 网络摄像头
                                                     img.jpg                         # 图像
                                                     vid.mp4                         # 视频
                                                     screen                          # 截屏
                                                     path/                           # 目录
                                                     list.txt                        # 图像列表
                                                     list.streams                    # 流列表
                                                     'path/*.jpg'                    # glob
                                                     'https://youtu.be/LNwODJXcvt4'  # YouTube
                                                     'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP 流

使用方法 - 格式:
    $ python detect.py --weights yolov5s.pt                 # PyTorch
                                 yolov5s.torchscript        # TorchScript
                                 yolov5s.onnx               # ONNX 运行时或 OpenCV DNN 使用 --dnn
                                 yolov5s_openvino_model     # OpenVINO
                                 yolov5s.engine             # TensorRT
                                 yolov5s.mlmodel            # CoreML (仅限 macOS)
                                 yolov5s_saved_model        # TensorFlow SavedModel
                                 yolov5s.pb                 # TensorFlow GraphDef
                                 yolov5s.tflite             # TensorFlow Lite
                                 yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
                                 yolov5s_paddle_model       # PaddlePaddle
"""

import argparse  # 导入命令行参数解析模块
import csv  # 导入CSV文件操作模块
import os  # 导入操作系统接口模块
import platform  # 导入用于获取操作系统信息的模块
import sys  # 导入系统模块
from pathlib import Path  # 导入路径操作模块

import torch  # 导入PyTorch模块

FILE = Path(__file__).resolve()  # 获取当前文件的绝对路径
ROOT = FILE.parents[0]  # 获取YOLOv5的根目录
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # 将根目录添加到系统路径
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # 获取根目录的相对路径

from ultralytics.utils.plotting import Annotator, colors, save_one_box  # 导入绘图工具

from models.common import DetectMultiBackend  # 导入模型后端选择工具
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams  # 导入数据加载工具
from utils.general import (
    LOGGER,  # 导入日志工具
    Profile,  # 导入性能分析工具
    check_file,  # 导入文件检查工具
    check_img_size,  # 导入图像尺寸检查工具
    check_imshow,  # 导入显示检查工具
    check_requirements,  # 导入依赖检查工具
    colorstr,  # 导入颜色字符串工具
    cv2,  # 导入OpenCV模块
    increment_path,  # 导入路径增量工具
    non_max_suppression,  # 导入非最大抑制工具
    print_args,  # 导入参数打印工具
    scale_boxes,  # 导入框缩放工具
    strip_optimizer,  # 导入优化器剥离工具
    xyxy2xywh,  # 导入坐标转换工具
)
from utils.torch_utils import select_device, smart_inference_mode  # 导入设备选择和智能推理模式工具

@smart_inference_mode()
def run(
    weights=ROOT / "yolov5s.pt",  # 模型权重路径或Triton服务器的URL
    source=ROOT / "data/images",  # 输入源路径,可以是文件、目录、URL、glob模式或屏幕截图
    data=ROOT / "data/coco128.yaml",  # 数据集配置文件路径
    imgsz=(640, 640),  # 推理时的图像尺寸(高度,宽度)
    conf_thres=0.25,  # 置信度阈值,用于过滤检测结果
    iou_thres=0.45,  # 非最大抑制(NMS)的IOU阈值
    max_det=1000,  # 每张图像的最大检测目标数
    device="",  # 指定运行设备,如CUDA设备('0'或'0,1,2,3')或CPU('cpu')
    view_img=False,  # 是否显示检测结果图像
    save_txt=False,  # 是否将检测结果保存为txt文件
    save_csv=False,  # 是否将检测结果保存为CSV格式文件
    save_conf=False,  # 是否在txt文件中保存置信度
    save_crop=False,  # 是否保存裁剪出的检测目标图像
    nosave=False,  # 是否不保存检测结果图像/视频
    classes=None,  # 指定需要检测的类别索引列表
    agnostic_nms=False,  # 是否使用类别不可知的NMS
    augment=False,  # 是否使用数据增强进行推理
    visualize=False,  # 是否可视化网络特征
    update=False,  # 是否更新所有模型
    project=ROOT / "runs/detect",  # 检测结果保存的项目目录
    name="exp",  # 在项目目录中的子目录名称
    exist_ok=False,  # 如果项目/名称已存在,是否覆盖
    line_thickness=3,  # 绘制边框的线条粗细
    hide_labels=False,  # 是否隐藏标签
    hide_conf=False,  # 是否隐藏置信度
    half=False,  # 是否使用半精度FP16进行推理
    dnn=False,  # 是否使用OpenCV DNN进行ONNX推理
    vid_stride=1,  # 视频帧处理的步长
):
    source = str(source)
    save_img = not nosave and not source.endswith(".txt")  # 判断是否保存推理图像
    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)  # 判断源是否为文件
    is_url = source.lower().startswith(("rtsp://", "rtmp://", "http://", "https://"))  # 判断源是否为URL
    webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)  # 判断源是否为网络摄像头
    screenshot = source.lower().startswith("screen")  # 判断源是否为屏幕截图
    if is_url and is_file:
        source = check_file(source)  # 如果是URL且指向文件,则下载文件

    # 创建保存目录
    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)
    (save_dir / "labels" if save_txt else save_dir).mkdir(parents=True, exist_ok=True)

    # 加载模型
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size(imgsz, s=stride)  # 校验图像尺寸

    # 数据加载器
    bs = 1  # 批量大小
    if webcam:
        view_img = check_imshow(warn=True)
        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
        bs = len(dataset)
    elif screenshot:
        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
    vid_path, vid_writer = [None] * bs, [None] * bs

    # 运行推理
    model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # 预热模型
    seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))
    for path, im, im0s, vid_cap, s in dataset:
        with dt[0]:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # 转换数据类型
            im /= 255  # 归一化图像数据
            if len(im.shape) == 3:
                im = im[None]  # 增加批次维度

        # 推理
        with dt[1]:
            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
            if model.xml and im.shape[0] > 1:
                pred = None
                for image in ims:
                    if pred is None:
                        pred = model(image, augment=augment, visualize=visualize).unsqueeze(0)
                    else:
                        pred = torch.cat((pred, model(image, augment=augment, visualize=visualize).unsqueeze(0)), dim=0)
                pred = [pred, None]
            else:
                pred = model(im, augment=augment, visualize=visualize)

        # 非最大抑制
        with dt[2]:
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

        # 定义CSV文件路径
        csv_path = save_dir / "predictions.csv"

        # 创建或追加到CSV文件
        def write_to_csv(image_name, prediction, confidence):
            """将图像的预测数据写入CSV文件,如果文件存在则追加。"""
            data = {"Image Name": image_name, "Prediction": prediction, "Confidence": confidence}
            with open(csv_path, mode="a", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=data.keys())
                if not csv_path.is_file():
                    writer.writeheader()
                writer.writerow(data)

        # 处理预测结果
        for i, det in enumerate(pred):  # 每张图像
            seen += 1
            if webcam:  # 如果是网络摄像头
                p, im0, frame = path[i], im0s[i].copy(), dataset.count
                s += f"{i}: "
            else:
                p, im0, frame = path, im0s.copy(), getattr(dataset, "frame", 0)

            p = Path(p)  # 转换为路径对象
            save_path = str(save_dir / p.name)  # 构建保存路径
            txt_path = str(save_dir / "labels" / p.stem) + ("" if dataset.mode == "image" else f"_{frame}")  # 构建文本文件路径
            s += "%gx%g " % im.shape[2:]  # 打印图像尺寸
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # 获取归一化增益
            imc = im0.copy() if save_crop else im0  # 如果保存裁剪,则复制图像
            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
            if len(det):
                # 从图像尺寸调整边框到原始尺寸
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                # 打印结果
                for c in det[:, 5].unique():
                    n = (det[:, 5] == c).sum()  # 每类的检测数量
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # 添加到字符串

                # 写入结果
                for *xyxy, conf, cls in reversed(det):
                    c = int(cls)  # 类别索引
                    label = names[c] if hide_conf else f"{names[c]}"
                    confidence = float(conf)
                    confidence_str = f"{confidence:.2f}"

                    if save_csv:
                        write_to_csv(p.name, label, confidence_str)

                    if save_txt:  # 写入文件
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # 归一化xywh
                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # 标签格式
                        with open(f"{txt_path}.txt", "a") as f:
                            f.write(("%g " * len(line)).rstrip() % line + "\n")

                    if save_img or save_crop or view_img:  # 在图像上添加边框
                        c = int(cls)  # 类别索引
                        label = None if hide_labels else (names[c] if hide_conf else f"{names[c]} {conf:.2f}")
                        annotator.box_label(xyxy, label, color=colors(c, True))
                    if save_crop:
                        save_one_box(xyxy, imc, file=save_dir / "crops" / names[c] / f"{p.stem}.jpg", BGR=True)

            # 流式传输结果
            im0 = annotator.result()
            if view_img:
                if platform.system() == "Linux" and p not in windows:
                    windows.append(p)
                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # 允许窗口调整大小(Linux)
                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1毫秒

            # 保存结果(带检测框的图像)
            if save_img:
                if dataset.mode == "image":
                    cv2.imwrite(save_path, im0)
                else:  # 'video' 或 'stream'
                    if vid_path[i] != save_path:  # 新视频
                        vid_path[i] = save_path
                        if isinstance(vid_writer[i], cv2.VideoWriter):
                            vid_writer[i].release()  # 释放之前的视频写入器
                        if vid_cap:  # 视频
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # 流
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path = str(Path(save_path).with_suffix(".mp4"))  # 强制使用*.mp4后缀
                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
                    vid_writer[i].write(im0)

        # 打印时间(仅推理)
        LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")

    # 打印结果
    t = tuple(x.t / seen * 1e3 for x in dt)  # 每张图像的速度
    LOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ""
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
    if update:
        strip_optimizer(weights[0])  # 更新模型(以修复SourceChangeWarning)

def parse_opt():
    """解析命令行参数,设置 YOLOv5 检测的推理选项和模型配置。"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="模型路径或 triton URL")
    parser.add_argument("--source", type=str, default=ROOT / "data/images", help="文件/目录/URL/glob/屏幕/0(网络摄像头)")
    parser.add_argument("--data", type=str, default=ROOT / "data/coco128.yaml", help="(可选) dataset.yaml 路径")
    parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="推理尺寸 h,w")
    parser.add_argument("--conf-thres", type=float, default=0.25, help="置信度阈值")
    parser.add_argument("--iou-thres", type=float, default=0.45, help="NMS IOU阈值")
    parser.add_argument("--max-det", type=int, default=1000, help="每张图像的最大检测数量")
    parser.add_argument("--device", default="", help="cuda 设备, 例如 0 或 0,1,2,3 或 cpu")
    parser.add_argument("--view-img", action="store_true", help="显示结果")
    parser.add_argument("--save-txt", action="store_true", help="保存结果到 *.txt")
    parser.add_argument("--save-csv", action="store_true", help="以CSV格式保存结果")
    parser.add_argument("--save-conf", action="store_true", help="在 --save-txt 标签中保存置信度")
    parser.add_argument("--save-crop", action="store_true", help="保存裁剪后的预测框")
    parser.add_argument("--nosave", action="store_true", help="不保存图像/视频")
    parser.add_argument("--classes", nargs="+", type=int, help="按类过滤: --classes 0, 或 --classes 0 2 3")
    parser.add_argument("--agnostic-nms", action="store_true", help="类别不可知的NMS")
    parser.add_argument("--augment", action="store_true", help="增强推理")
    parser.add_argument("--visualize", action="store_true", help="可视化特征")
    parser.add_argument("--update", action="store_true", help="更新所有模型")
    parser.add_argument("--project", default=ROOT / "runs/detect", help="将结果保存到 project/name")
    parser.add_argument("--name", default="exp", help="将结果保存到 project/name")
    parser.add_argument("--exist-ok", action="store_true", help="存在的 project/name 是可以的,不需要递增")
    parser.add_argument("--line-thickness", default=3, type=int, help="边框厚度 (像素)")
    parser.add_argument("--hide-labels", default=False, action="store_true", help="隐藏标签")
    parser.add_argument("--hide-conf", default=False, action="store_true", help="隐藏置信度")
    parser.add_argument("--half", action="store_true", help="使用 FP16 半精度推理")
    parser.add_argument("--dnn", action="store_true", help="使用 OpenCV DNN 进行 ONNX 推理")
    parser.add_argument("--vid-stride", type=int, default=1, help="视频帧率步长")
    opt = parser.parse_args()
    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # 扩展
    print_args(vars(opt))
    return opt

def main(opt):
    """使用给定选项执行 YOLOv5 模型推理,运行模型前检查要求。"""
    check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop"))
    run(**vars(opt))

if __name__ == "__main__":
    opt = parse_opt()
    main(opt)
  • 5
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: YOLOv5是一种快速、高效的深度学习算法,用于对象检测。它的官方源代码提供了全中文注释,使得理解和使用变得更加容易。 官方源代码中的注释清晰明了,覆盖了每个函数的功能、输入及输出参数、返回值等重要信息。此外,注释还解释了源代码中的关键概念和算法,如Anchor boxes、Backbone、FPN等。 通过阅读并理解YOLOv5官方源代码注释,用户可以更好地掌握YOLOv5算法的核心思想和实现方式。在使用过程中,用户可以根据自己的需要进行改进和调整,实现更好的对象检测效果。 总之,YOLOv5官方源代码提供了全中文注释,为用户提供了可靠且易于理解的代码实现。通过研究和使用这些代码,用户可以更好地掌握深度学习算法,并开发出更加先进的应用。 ### 回答2: YOLOv5是一个高性能的目标检测算法,由于其强大的检测能力和在不同方向上的多任务训练能力而备受关注。为了使用户更好地了解该算法的具体实现,官方提供了全中文注释源码。 YOLOv5官方全中文注释源码包括了YOLOv5算法的所有核心代码,所有源码都有详细的中文注释,方便用户理解每一行代码的含义和作用。在这些代码注释中,用户可以了解到该算法的实现方式和各种技术细节,包括所有的网络结构、损失函数、数据集预处理方法、数据增强方法、评估指标等细节。 此外,中文注释源码中还包含一些有用的代码样例,方便用户快速上手和实践。通过访问官方的GitHub仓库,用户可以轻松地下载和使用所有源码和注释,以及预训练好的模型。 总的来说,YOLOv5官方全中文注释源码是一个非常有用的工具,它不仅让用户更加了解该算法的实现细节,还能够帮助用户在实践中解决性能问题和技术难题,是一个非常实用的资源。 ### 回答3: YOLOv5是一种基于深度学习的物体检测算法,它能够快速、准确地检测图像中的目标物体。现在,YOLOv5官方已经发布了全中文注释的源代码,这意味着我们可以更好地了解这个算法的工作原理以及如何在实际应用中进行调试和改进。 在这份全中文注释的源代码中,我们可以看到许多注释,涵盖了很多关键的算法细节和技术细节。这些注释让我们更好地理解代码的实现细节,例如图像预处理、网络架构、损失函数、目标检测的评估指标等等。这些注释还提供了一些实用的技巧和技能,帮助我们更好地理解和使用YOLOv5算法。 除了注释之外,官方给出了多个使用示例,这些示例可以帮助我们更好地了解如何使用YOLOv5算法进行物体检测。此外,官方还提供了一些训练好的模型,这些模型可直接用于自己的应用中,省去了自己训练模型的时间和精力。 总的来说,YOLOv5官方全中文注释源码是一个非常有用的开源资源,可以帮助我们更好地了解和使用该算法。在未来的应用中,我们可以用它来改进和优化自己的物体检测系统,以更好地满足实际需求。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值