detect.py是Ultralytics的YOLOv5目标检测框架的一部分,这个脚本用于在不同类型的源(例如图像、视频、目录、网络摄像头和流媒体)上运行检测推理。
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run YOLOv5 detection inference on images, videos, directories, globs, YouTube, webcam, streams, etc.
Usage - sources:
$ python detect.py --weights yolov5s.pt --source 0 # webcam
img.jpg # image
vid.mp4 # video
screen # screenshot
path/ # directory
'path/*.jpg' # glob
'https://youtu.be/Zgi9g1ksQHc' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
Usage - formats:
$ python detect.py --weights yolov5s.pt # PyTorch
yolov5s.torchscript # TorchScript
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
yolov5s_openvino_model # OpenVINO
yolov5s.engine # TensorRT
yolov5s.mlmodel # CoreML (macOS-only)
yolov5s_saved_model # TensorFlow SavedModel
yolov5s.pb # TensorFlow GraphDef
yolov5s.tflite # TensorFlow Lite
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
yolov5s_paddle_model # PaddlePaddle
"""
import argparse
import os
import platform
import sys
from pathlib import Path
import torch
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode
@smart_inference_mode()
def run(
weights=ROOT / 'yolov5s.pt', # model path or triton URL
source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
imgsz=(640, 640), # inference size (height, width)
conf_thres=0.25, # confidence threshold
iou_thres=0.45, # NMS IOU threshold
max_det=1000, # maximum detections per image
device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
view_img=False, # show results
save_txt=False, # save results to *.txt
save_conf=False, # save confidences in --save-txt labels
save_crop=False, # save cropped prediction boxes
nosave=False, # do not save images/videos
classes=None, # filter by class: --class 0, or --class 0 2 3
agnostic_nms=False, # class-agnostic NMS
augment=False, # augmented inference
visualize=False, # visualize features
update=False, # update all models
project=ROOT / 'runs/detect', # save results to project/name
name='exp', # save results to project/name
exist_ok=False, # existing project/name ok, do not increment
line_thickness=3, # bounding box thickness (pixels)
hide_labels=False, # hide labels
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
vid_stride=1, # video frame-rate stride
):
source = str(source)
save_img = not nosave and not source.endswith('.txt') # save inference images
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
if is_url and is_file:
source = check_file(source) # download
# Directories
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size(imgsz, s=stride) # check image size
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
vid_path, vid_writer = [None] * bs, [None] * bs
# Run inference
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
for path, im, im0s, vid_cap, s in dataset:
with dt[0]:
im = torch.from_numpy(im).to(model.device)
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
# Inference
with dt[1]:
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred = model(im, augment=augment, visualize=visualize)
# NMS
with dt[2]:
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
# Second-stage classifier (optional)
# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
# Process predictions
for i, det in enumerate(pred): # per image
seen += 1
if webcam: # batch_size >= 1
p, im0, frame = path[i], im0s[i].copy(), dataset.count
s += f'{i}: '
else:
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
p = Path(p) # to Path
save_path = str(save_dir / p.name) # im.jpg
txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
s += '%gx%g ' % im.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
imc = im0.copy() if save_crop else im0 # for save_crop
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
# Print results
for c in det[:, 5].unique():
n = (det[:, 5] == c).sum() # detections per class
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
# Write results
for *xyxy, conf, cls in reversed(det):
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(f'{txt_path}.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
if save_img or save_crop or view_img: # Add bbox to image
c = int(cls) # integer class
label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
annotator.box_label(xyxy, label, color=colors(c, True))
if save_crop:
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
# Stream results
im0 = annotator.result()
if view_img:
if platform.system() == 'Linux' and p not in windows:
windows.append(p)
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
cv2.imshow(str(p), im0)
cv2.waitKey(1) # 1 millisecond
# Save results (image with detections)
if save_img:
if dataset.mode == 'image':
cv2.imwrite(save_path, im0)
else: # 'video' or 'stream'
if vid_path[i] != save_path: # new video
vid_path[i] = save_path
if isinstance(vid_writer[i], cv2.VideoWriter):
vid_writer[i].release() # release previous video writer
if vid_cap: # video
fps = vid_cap.get(cv2.CAP_PROP_FPS)
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 30, im0.shape[1], im0.shape[0]
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer[i].write(im0)
# Print time (inference-only)
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
# Print results
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
if save_txt or save_img:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
if update:
strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model path or triton URL')
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--view-img', action='store_true', help='show results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--visualize', action='store_true', help='visualize features')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(vars(opt))
return opt
def main(opt):
check_requirements(exclude=('tensorboard', 'thop'))
run(**vars(opt))
if __name__ == "__main__":
opt = parse_opt()
main(opt)
以下是脚本组件和功能的简要概述:
- 导入依赖项: 脚本首先导入必要的库和模块,包括PyTorch、argparse(用于解析命令行参数)以及YOLOv5仓库中的各种实用函数。
- 设置和配置: 它通过将YOLOv5的根目录添加到系统路径和准备必要的常量和文件路径来设置环境。
- 检测函数(run): 处理检测的核心函数。它接受多个参数,如权重(模型路径)、源(输入数据)、imgsz(图像大小)、device(CUDA或CPU)以及用于推理的其他各种配置。
- 数据加载: 根据源的不同,脚本会为图像、视频、网络摄像头流或屏幕截图创建数据加载器。
- 模型加载: 它根据weights参数加载指定的目标检测模型,并为推理设置设备(GPU或CPU)。
- 推理循环: 脚本遍历加载的数据并执行以下步骤:
- 图像预处理(调整大小、标准化)
- 使用加载的模型进行推理
- 非极大值抑制(NMS)以过滤掉重叠的边界框
- 可选的检测可视化和结果保存
- 保存结果: 检测到的边界框可以使用–save-txt标志保存为图像或文本文件。还可以使用–save-crop标志保存检测到的裁剪图像。
- 参数解析(parse_opt): 脚本使用argparse定义并解析命令行参数。它允许用户自定义运行检测的各种设置。
- **主入口点:**main函数处理解析后的参数,并通过提供的选项调用run函数来启动检测过程。
要使用这个脚本,您需要安装YOLOv5及其必要的依赖项(PyTorch、OpenCV等)。您可以在命令行中调用该脚本,并为您希望执行的任务提供适当的参数。
例如,要使用YOLOv5s模型在网络摄像头上运行检测,可以执行以下命令:
python detect.py --weights yolov5s.pt --source ../../12-1.mp4 --conf-thres 0.25 --iou-thres 0.45
主入口点
在Python脚本中,main函数通常作为程序的主入口点,负责处理命令行参数并启动程序的主要逻辑。以下是一个示例,说明如何在脚本中定义main函数,处理解析后的参数,并调用run函数来启动检测过程:
import argparse
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='模型权重文件路径')
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='图像/视频源')
parser.add_argument('--img-size', type=int, default=[640, 640], help='推理时的图像大小')
parser.add_argument('--conf-thres', type=float, default=0.25, help='目标置信度阈值')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS中的IOU阈值')
# ...更多参数
opt = parser.parse_args()
return opt
def run(
weights=ROOT / 'yolov5s.pt', # model path or triton URL
source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
imgsz=(640, 640), # inference size (height, width)
# ... 其他参数 ...
):
def main(opt):
# 主函数,它接收命令行参数对象,并使用这些参数来执行程序的主要逻辑
print(f"Running detection with options: {opt}")
check_requirements(exclude=('tensorboard', 'thop'))
run(**vars(opt)) # 调用run函数,传递解析后的参数
if __name__ == "__main__":
# 当脚本被直接运行时,这个代码块会被执行
opt = parse_opt() # 调用parse_opt函数来解析命令行参数
main(opt) # 调用main函数并传入解析后的参数对象
在这个示例脚本中,main函数的作用是:
- 接收参数: main函数的参数opt是一个包含了所有命令行参数值的对象,通常由parse_opt函数返回。
- 打印信息: print(f"Running detection with options: {opt}")这行代码是一个简单的日志,用于打印当前的运行配置信息。
- 调用检测函数: run()这行代码调用了run函数,并将解析后的参数对象opt传递给它。这样做是为了启动实际的检测过程,run函数内部实现了加载模型、处理数据和执行推理等操作。
- 入口点检查: if name == “main”:这行代码检查脚本是否是作为主程序直接运行而不是作为模块被导入。如果是直接运行,会执行以下动作:
- opt = parse_opt(): 调用parse_opt函数来从命令行解析参数。
- main(opt): 解析完成后,将参数传递给main函数,并执行程序的主逻辑。
- **check_requirements:**check_requirements函数的作用是检查当前Python环境中是否已经安装了指定的依赖包,以确保脚本或程序能够正常运行。通常,这个函数会在脚本的开始处调用,作为环境检查的一部分。
函数的参数exclude是一个包含要排除在外的依赖包名称的元组。在这个例子中,exclude=(‘tensorboard’, ‘thop’)表示在检查时不考虑tensorboard和thop这两个包,即使它们在依赖列表中,也不会验证它们是否安装。
使用 check_requirements 函数可以确保在运行程序之前所有必要的依赖都已经满足,防止因缺少依赖而导致的运行时错误。同时,提供排除选项可以给予开发者更多的灵活性,使他们可以根据具体情况决定哪些依赖是必要的,哪些可以忽略。
- **run(vars(opt)): 在Python中,**操作符用于将字典拆解成关键字参数传递给函数。当你有一个字典,并且你想要将它的键值对作为关键字参数传递给一个函数时,这个操作符会很有用。
vars()函数是Python内置函数,它返回一个对象的__dict__属性。对于一个由argparse解析后的参数命名空间对象,vars()函数会返回一个包含所有参数的字典。每个参数名作为字典的键,对应的参数值作为字典的值。
结合在一起,run(**vars(opt))这行代码的作用是将通过argparse解析得到的命令行参数(包含在opt对象中),转换成一个参数字典,然后将这个字典拆解成多个关键字参数传递给run函数。
举个例子,假设你有以下代码:
import argparse
def run(a, b, c):
print(f'a = {a}, b = {b}, c = {c}')
parser = argparse.ArgumentParser()
parser.add_argument('--a', type=int)
parser.add_argument('--b', type=int)
parser.add_argument('--c', type=int)
opt = parser.parse_args(['--a', '1', '--b', '2', '--c', '3'])
run(**vars(opt))
在这个例子中,vars(opt)会返回一个字典{‘a’: 1, ‘b’: 2, ‘c’: 3}。run(**vars(opt))将这个字典转换成run(a=1, b=2, c=3)的调用。这种方式简化了将一个对象的属性作为参数传递给函数的过程,特别是当参数数量较多时。
参数解析(parse_opt)
在Python脚本中,argparse库用于添加和解析命令行参数,允许用户在运行脚本时自定义配置。以下是使用argparse来定义和解析命令行参数的一个示例:
import argparse
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='模型权重文件路径')
parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='图像/视频源')
parser.add_argument('--img-size', type=int, default=[640, 640], help='推理时的图像大小')
parser.add_argument('--conf-thres', type=float, default=0.25, help='目标置信度阈值')
parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS中的IOU阈值')
# ...更多参数
opt = parser.parse_args()
return opt
def main(opt):
# 使用解析后的参数
print(f"Model weights: {opt.weights}")
# ...调用run函数或其他函数
if __name__ == "__main__":
opt = parse_opt()
main(opt)
这段代码中的主要部分包括:
- 创建解析器 (argparse.ArgumentParser): 这个对象用于处理命令行参数。
- 添加参数 (add_argument): 使用add_argument方法为每个期望的命令行参数定义规则。比如,参数名称、类型、默认值和帮助信息等。
- 解析参数 (parse_args): 通过调用parse_args方法来获取命令行输入的参数。这会返回一个包含所有参数值的命名空间对象。
- 使用参数: 获取的参数可以被直接传递给main函数或其他任何需要这些参数的函数。
在这个例子中,–weights参数用于设置模型权重文件的路径,–source参数用于设置输入源,–img-size用于设置模型的输入图像尺寸,–conf-thres和–iou-thres分别用于设置目标检测的置信度阈值和NMS中的IOU阈值。
parse_opt函数的目的是在脚本开始执行时解析所有的命令行参数,并返回一个对象,这个对象包含了所有命令行参数的值。之后,这些参数将被用于配置目标检测过程。
导入依赖项
代码中的导入依赖项部分主要是为了确保脚本在执行时能够调用到YOLOv5的模型、工具函数以及其他一些必需的Python库。下面是详细说明:
import argparse # 用于解析命令行参数
import os # 提供了与操作系统进行交互的功能,比如文件路径操作
import platform # 用来获取当前系统的信息
import sys # 提供了一系列关于Python运行环境的变量和函数
from pathlib import Path # 提供面向对象的文件系统路径操作
import torch # 导入PyTorch库,用于深度学习模型的运算
# 从YOLOv5的各个模块中导入所需的类和函数
from models.common import DetectMultiBacked # 导入多后端检测模型类
# 导入数据加载相关的类和常量
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import ( # 导入通用函数和类
LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements,
colorstr, cv2, increment_path, non_max_suppression, print_args, scale_boxes,
strip_optimizer, xyxy2xywh
)
from utils.plots import Annotator, colors, save_one_box # 导入用于标注和保存图像的工具函数
from utils.torch_utils import select_device, smart_inference_mode # 导入与PyTorch相关的工具函数
这些导入项包含了运行YOLOv5检测脚本所必须的各种工具函数和类。以下是一些关键模块的用途:
- DetectMultiBackend: 用于加载和使用不同后端(如PyTorch、TorchScript、ONNX等)的目标检测模型。
- LoadImages, LoadStreams, LoadScreenshots: 用于从不同的源(如文件、视频流或屏幕截图)加载输入数据。
- check_file, check_img_size, non_max_suppression, xyxy2xywh: 提供图像预处理、后处理等功能,用于确保输入图像符合模型要求,以及对模型输出进行非极大值抑制和坐标转换。
- LOGGER: 用于日志记录。
- Annotator, colors, save_one_box: 用于在图像上绘制边界框和类别标签,并在需要时保存检测结果。
此外,这段代码还包含了一些智能推理模式和设备选择的辅助函数,如select_device和smart_inference_mode,这些函数帮助脚本在不同的硬件上高效运行(例如,自动在GPU和CPU之间切换)。
总的来说,导入这些依赖项是为了确保脚本有完整的功能,能够处理图像输入、执行检测、处理结果并与操作系统交互。
设置和配置
设置和配置部分的代码主要关注在环境准备、路径处理和模块导入。这是为了确保脚本能够正确地找到模型文件、数据集配置,以及其他资源,并能够顺利运行。下面逐一解释这部分代码的作用:
# 获取当前文件的绝对路径,并解析为Path对象
FILE = Path(__file__).resolve()
# 获取YOLOv5根目录的绝对路径,即当前脚本的父目录
ROOT = FILE.parents[0]
# 将YOLOv5根目录添加到系统环境变量中,如果还没有添加的话
# 这样做是为了确保脚本中的后续导入语句可以正确找到并导入根目录下的模块
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
# 将ROOT转换为相对于当前工作目录的相对路径
# 这有助于在不同的文件系统或操作系统中更灵活地使用脚本
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
- **解析文件路径:**Path(file).resolve() 获取当前脚本文件的绝对路径,并确保这个路径是解析后的路径,这意味着它会遵循所有的符号链接,并返回最终指向的文件路径。
- **确定根目录:**ROOT = FILE.parents[0] 利用Path对象的parents属性找到当前脚本的父目录,这通常是YOLOv5项目的根目录,因为detect.py通常位于项目的根目录。
- **添加到系统路径:**if str(ROOT) not in sys.path: sys.path.append(str(ROOT)) 将YOLOv5的根目录添加到Python的系统路径sys.path中。这样做是为了确保在运行脚本时可以从任何地方导入项目根目录下的模块。
- **计算相对路径:**ROOT = Path(os.path.relpath(ROOT, Path.cwd())) 将根目录的绝对路径转换为相对于当前工作目录(你运行脚本的位置)的相对路径。
这整个设置和配置过程确保了无论YOLOv5项目的位置如何,脚本总是可以正确地找到并导入需要的模块和文件。它使代码更加灵活和可移植,无论是在开发环境中还是在部署时。
检测函数(run)
检测函数 run 是脚本中的核心部分,负责实现整个目标检测流程。它接受一系列参数,用以配置模型权重、输入源、图像大小、检测阈值等。下面是详细的代码说明:
@smart_inference_mode()
def run(
weights=ROOT / 'yolov5s.pt', # model path or triton URL
source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
data=ROOT / 'data/coco128.yaml', # dataset.yaml path
imgsz=(640, 640), # inference size (height, width)
# ... 其他参数 ...
):
# 将输入源转换为字符串格式
source = str(source)
# 定义是否保存图像的标志,nosave=Ture的时就不保存,默认保存
save_img = not nosave and not source.endswith('.txt') # save inference images
# 检查输入源是否为文件或视频
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
# 检查输入源是否为流媒体URL
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
# 检查输入源是否为网络摄像头
webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
# 检查输入源是否为屏幕截图
screenshot = source.lower().startswith('screen')
# 如果输入源是网络URL且为文件,检查文件是否需要下载
if is_url and is_file:
source = check_file(source) # download
# 创建用于保存结果的目录
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
# 加载模型
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size(imgsz, s=stride) # check image size
# 数据加载器的设置
bs = 1 # batch_size
if webcam:
# 检查是否支持显示图像
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
# 初始化视频路径和视频写入器
vid_path, vid_writer = [None] * bs, [None] * bs
# 模型预热
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
# 推理和后处理的时间统计
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
# 推理循环,遍历数据集
for path, im, im0s, vid_cap, s in dataset:
# 以下是推理和后处理的具体实现,
# 包括预处理图像、进行模型推理、应用NMS、
# 将结果绘制到图像上、保存图像或视频等。
# Print time (inference-only)
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
# 打印总结信息
LOGGER.info(f'Speed: ...')
if save_txt or save_img:
LOGGER.info(f"Results saved to {save_dir}")
# 更新模型(可选)
if update:
strip_optimizer(weights[0]) # update model
该函数使用装饰器 @smart_inference_mode(),这个smart_inference_mode装饰器的目的是为了智能地应用最合适的推理模式上下文管理器,以提升PyTorch模型在进行推理时的性能和效率。具体来说,根据PyTorch的版本是否大于等于1.9.0,它会选择使用torch.inference_mode或torch.no_grad。
让我们逐一看看这个装饰器的关键部分:
- 版本检查 (check_version(torch.version, ‘1.9.0’)): 这个调用会检查当前PyTorch的版本是否大于等于1.9.0。
- 装饰器选择: 如果PyTorch版本大于等于1.9.0,torch_1_9将为True,装饰器将使用torch.inference_mode;如果版本低于1.9.0,将使用torch.no_grad。
- decorate函数: 这是内部定义的装饰器函数,它接受一个函数fn作为参数,并返回应用了相应上下文管理器的新函数。
下面是具体的工作流程:
- 如果版本满足(torch >= 1.9.0),torch.inference_mode会被应用。这是在PyTorch 1.9.0版本中引入的一个新特性,它不仅禁用了梯度计算(和torch.no_grad一样),还进一步优化了某些操作以提高推理效率。
- 如果版本不满足(torch < 1.9.0),则使用老版本的torch.no_grad上下文管理器,它仅仅禁用了梯度计算,减少了内存使用并加快了计算速度,但没有torch.inference_mode中的额外优化。
这样,smart_inference_mode装饰器能够根据PyTorch的版本智能地选择最佳的推理模式,确保代码能在不同版本的PyTorch上以最优的方式运行。这对于写出与版本无关的、性能良好的代码尤其有用。使用该装饰器的方式大概如下:
@smart_inference_mode()
def my_inference_function(x):
# ... 执行推理 ...
在这里my_inference_function在执行时,将会处于适应其PyTorch版本的推理模式中。
以下是您提供的run函数参数的说明:
- weights: 模型权重文件路径或Triton URL。默认是ROOT目录中的’yolov5s.pt’。
- source: 输入数据源。可以是文件路径、目录、URL、glob模式、'screen’用于屏幕捕获,或’0’代表使用网络摄像头。
- data: 数据集的YAML配置文件路径。默认是ROOT目录中的’coco128.yaml’。
- imgsz: 推理时使用的图像大小,以(高度, 宽度)的元组形式给出。
- conf_thres: 置信度阈值,用于过滤检测结果,低于此值的检测会被丢弃。
- iou_thres: 非最大抑制(NMS)的IOU阈值。用于消除重叠的边界框。
- max_det: 每张图像允许的最大检测数量。
- device: 用于推理的计算设备,例如’cpu’或表示GPU设备ID的字符串,如’0’或’0,1,2,3’。
- view_img: 是否在窗口中显示结果。适用于实时可视化。
- save_txt: 是否将检测结果保存到文本文件(*.txt)。
- save_conf: 当save_txt为True时,是否在保存的文本文件中包含置信度分数。
- save_crop: 是否将裁剪后的检测框保存为图像文件。
- nosave: 如果为True,则不保存结果图像或视频。
- classes: 如果指定,仅过滤给定类别ID列表的检测结果。
- agnostic_nms: 如果为True,执行类别不可知的NMS,这在NMS过程中合并了所有类别的边界框。
- augment: 如果为True,使用增强推理技术,可能提升检测性能。
- visualize: 如果为True,可视化神经网络特征以帮助理解和调试。
- update: 如果为True,更新所有模型至最新可用版本。
- project: 结果保存的基本目录。每次推理运行将在该路径下创建一个新目录。
- name: 在project下保存结果的特定子目录名称。
- exist_ok: 如果为True,当项目目录中已存在同名目录时,不创建一个新的编号目录。
- line_thickness: 边界框线条的厚度(像素)。
- hide_labels: 如果为True,不在检测框上绘制标签(类名)。
- hide_conf: 如果为True,不在检测框上显示置信度分数。
- half: 如果为True,使用FP16半精度推理,这可以在兼容的GPU上加快计算速度。
- dnn: 如果为True,使用OpenCV的DNN模块进行ONNX模型推理。
- vid_stride: 视频帧速率的步长。如果大于1,则仅处理视频源的每第n帧。
这些参数允许用户定制检测过程的行为,包括输入数据、模型、输出、可视化和性能选项。通过调整这些参数,您可以控制检测阈值、检测的类别、输出格式和计算设备的使用等各个方面。
数据加载
在run函数中,数据加载部分的代码负责根据输入源(source参数)的不同类型(如文件、目录、网络摄像头、屏幕截图或视频流)来创建合适的数据加载器。这些数据加载器将会为推理过程逐个提供图像。下面来分析这部分代码:
# 将source参数转换为字符串,以便后续处理
source = str(source)
# 确定是否保存推理图像,取决于nosave标志和输入源是否为.txt文件
save_img = not nosave and not source.endswith('.txt')
# 检查输入源是否为文件或视频格式
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
# 检查输入源是否为网络流(RTSP, RTMP, HTTP, HTTPS)
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
# 确定输入源是否为网络摄像头或流媒体
webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
# 确定输入源是否为屏幕截图
screenshot = source.lower().startswith('screen')
# 如果输入源是网络流且是文件格式,检查并下载文件
if is_url and is_file:
source = check_file(source) # download
# 创建保存推理结果的目录
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
# 加载模型
# 设置batch size的初始值为1
bs = 1 # batch_size
# 根据输入源的类型,创建不同的数据加载器
if webcam: # 如果输入源是摄像头或者视频流
view_img = check_imshow(warn=True) # 检查是否能够实时显示图像
# 创建用于从摄像头或视频流加载数据的数据加载器
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset) # 如果输入源来自多个摄像头,batch size可能大于1
elif screenshot: # 如果输入源是屏幕截图
# 创建用于捕获屏幕截图的数据加载器
dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
else: # 如果输入源是文件或目录
# 创建用于从文件或目录加载图像的数据加载器
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
# 初始化视频路径和视频写入器的列表,用于保存视频结果
vid_path, vid_writer = [None] * bs, [None] * bs
让我们逐一看看这部分代码的功能:
- 设置默认的批处理大小 (bs):批处理大小即每次处理的图像数量,默认设置为1,这意味着系统将逐一处理每张图像。
- 选择数据加载器:根据输入源类型,脚本将选择合适的数据加载器以获取图像。
- 网络摄像头或视频流 (webcam): 使用LoadStreams类,支持多摄像头输入,能够从摄像头或实时视频流加载图像。如果输入源源自多个摄像头,bs(批处理大小)将设置为摄像头的数量。
- 屏幕截图 (screenshot): 使用LoadScreenshots类来直接从屏幕进行截图。
- 文件或目录: 若输入源是文件系统中的一个文件或目录,使用LoadImages类读取图像。这适用于处理单个图像文件、一组图像文件,或者一个包含图像文件的目录。
- 视频输出的准备:vid_path和vid_writer列表被初始化,用于保存视频推理结果。如果处理的是视频文件或者摄像头输入,这两个列表将用于存储每个视频源的输出路径和视频写入对象。
该数据加载部分是模型推理之前的准备过程的一部分,确保无论输入的是单张图片、一系列图片、视频还是实时流,都可以被适当读取,并以合适的格式提供给模型进行后续的推理处理。
vid_path, vid_writer = [None] * bs, [None] * bs
- [None] * bs 通过列表乘法创建一个新列表,该列表由 bs 个 None 组成。列表乘法是一种快捷方式,用于初始化一个具有固定长度并用相同元素填充的列表。
- vid_path 变量被赋值为第一个 [None] * bs 列表,而 vid_writer 变量被赋值为第二个 [None] * bs 列表。这意味着 vid_path 和 vid_writer 都是长度为 bs 的列表,它们的每个元素都初始化为 None。
这样的初始化可能用在一些需要跟踪或管理多个视频路径或视频写入器(例如,用于保存视频输出)的场合。在这种情况下,vid_path 可能用于存储每个输出视频的路径,而 vid_writer 可能是 cv2.VideoWriter 对象的列表,用于实际写入视频帧到文件。初始化为 None 可以作为一个占位符,表明列表中的每个位置还未被赋予实际的路径或 VideoWriter 对象,通常后续会在处理视频的过程中被逐个替换为具体的值。
模型加载部分
在run函数中,模型加载部分的代码负责初始化并加载指定的目标检测模型,准备好设备(如CPU或GPU),并根据提供的权重文件(weights参数)来配置模型。以下是该部分代码的详细解释:
# 选择设备,可以是CPU或GPU
device = select_device(device)
# 加载模型,DetectMultiBackend是一个可以根据提供的权重参数选择不同后端的包装器
# 支持PyTorch、ONNX、TensorRT等多种格式的权重文件
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
# 获取模型相关的信息,用于后续处理
# stride是模型在处理图像时的步长
# names是模型能够识别的类别名称列表
# pt是一个标志,如果模型是PyTorch格式,则为True
stride, names, pt = model.stride, model.names, model.pt
# 校验图像大小是否适合模型的步长要求
# 检查是否需要调整图像大小以匹配模型的步长,防止大小不兼容的情况发生
imgsz = check_img_size(imgsz, s=stride)
逐一分析这些代码部分的功能:
- 选择设备 (select_device): 该函数根据device参数的值来选择合适的设备。如果参数是’cpu’,它会选择CPU;如果参数是一个数字(如’0’),则会尝试选择对应的GPU设备。如果没有指定device,它会自动选择是否使用GPU。
- 加载模型 (DetectMultiBackend): DetectMultiBackend类是YOLOv5框架中用于支持多种模型格式的通用接口。这个类可以根据权重文件的类型(例如PyTorch模型.pt,ONNX模型.onnx等)来加载不同的模型。它还处理了是否使用半精度(FP16)运算以及其他模型相关配置。
- 获取模型信息: 从加载的模型中获取步长stride、类别名称列表names和一个布尔值pt,后者指示是否使用PyTorch模型。
- 校验并设置图像大小 (check_img_size): 这个函数确保输入图像的尺寸兼容模型的步长要求。步长对于卷积神经网络来说是图像在通过模型时采样的距离。为了使图像尺寸与步长相匹配,可能需要调整图像的大小。这样做是为了提高性能和减少可能的错误。
通过这些步骤,脚本将模型准备好并加载到了指定的计算设备上,使其随时准备进行图像的推理处理。在实际运行检测之前,这是一个关键的准备阶段。
推理循环
在run函数中的推理循环部分,代码实现了对加载的数据进行逐帧处理的功能。这包括图像的预处理、使用模型进行推理、应用非极大值抑制(NMS)处理重叠的边界框,并根据用户的配置选择性地进行结果可视化和保存。以下是这段代码的详细解释:
# 推理和后处理的时间统计
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
# 开始对加载的数据集中的每一帧图像进行处理
for path, im, im0s, vid_cap, s in dataset:
with dt[0]: # 开始记录预处理时间
# 转换im为PyTorch tensor,发送到指定的设备(CPU或GPU)
im = torch.from_numpy(im).to(device)
im = im.half() if model.fp16 else im.float() # 转换为半精度浮点(如果启用)或全精度浮点
im /= 255 # 归一化图像从[0, 255]到[0.0, 1.0]
# 在图像维度上添加一个batch维度,预测时需要这样
if len(im.shape) == 3:
im = im[None] # 添加batch维度
# 推理阶段
with dt[1]: # 开始记录推理时间
# 如果启用了visualize,则设置可视化路径
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
# 进行模型推理,传入图像和其他参数
pred = model(im, augment=augment, visualize=visualize)
# NMS阶段
with dt[2]: # 开始记录NMS时间
# 应用非极大值抑制(NMS),过滤掉一些重叠的边界框
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
# 处理预测结果
for i, det in enumerate(pred): # 遍历每张图像的检测结果
# ...省略:这里包含了针对每个检测到的目标的处理代码...
# 这些代码负责将检测结果转换为图片上的坐标,打印类别信息,写入结果文件,
# 绘制边界框和标签,保存视频帧等。
# 打印时间统计信息,推理时间仅包括模型的实际推理过程
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
让我们分解这个循环中的关键部分:
- 预处理:每帧图像im从NumPy数组转换为PyTorch张量,并被发送到指定的计算设备上。图像数据是从整数(0到255)归一化到浮点数(0.0到1.0)。如果图像只有三个维度(高度、宽度、颜色通道),还会添加一个批处理维度。
- 推理:使用模型(model)进行推理。如果设置了augment,将使用数据增强进行推理,visualize参数允许可视化模型的内部特征。
- 非极大值抑制(NMS):对模型的预测结果pred应用NMS,以去除重叠的边界框。这一步是根据置信度阈值(conf_thres)、IOU阈值(iou_thres)、类别(classes)、是否进行类别不可知NMS(agnostic_nms)和最大检测数(max_det)进行配置的。
- 处理预测结果:遍历每个图像的检测结果,进行后续处理如绘制边界框、保存检测结果到文件等。根据配置,可以保存标
- seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
这行代码在 Python 中执行了三个关键操作:
- 初始化变量 seen 并赋值为 0。
- 创建一个空列表 windows。
- 创建一个包含三个 Profile 实例的元组并赋值给变量 dt。
具体解释如下:
- seen = 0:这里定义了一个整型变量 seen 并初始化为 0。seen 变量通常用于跟踪已经处理过的数据量,例如在数据处理、阅读文件或者图像处理等任务中计数。
- windows = []:创建了一个名为 windows 的空列表。列表在 Python 中是用来存储一系列有序项目的容器。windows 可能会用于存储窗口对象或者其他类型的数据,具体取决于上下文。
- dt = (Profile(), Profile(), Profile()):创建了一个名为 dt 的元组,该元组包含三个 Profile 类的新实例。元组是一个不可变的序列类型,在 Python 中用圆括号 () 表示。Profile 很可能是在代码中定义的一个类,专注于性能分析,用于监视和跟踪代码执行的时间。因为这里创建了三个实例,可能对应于监视不同部分的性能,例如,加载时间、处理时间和保存/输出时间。
紧接着,代码进入了一个 for 循环,遍历 dataset 数据集中的每个元素。每个元素包含五个变量:path, im, im0s, vid_cap, 和 s,分别代表路径、处理后的图像数据、原始图像数据、视频捕获对象和状态信息。
在循环体中,有三个主要的步骤,每个步骤都包裹在 with dt[i] 上下文管理器内,这表示使用 Profile 对象测量每个步骤的执行时间:
- 数据准备
1with dt[0]:
2 im = torch.from_numpy(im).to(model.device)
3 im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
4 im /= 255 # 0 - 255 to 0.0 - 1.0
5 if len(im.shape) == 3:
6 im = im[None] # expand for batch dim
这一步骤将 NumPy 数组格式的图像 im 转换为 PyTorch 张量,并将其移动到模型所在的设备上(例如,GPU)。然后根据模型是否运行在半精度模式(fp16),将图像数据转换为相应的数据类型,并将像素值从范围 [0, 255] 归一化到 [0.0, 1.0]。如果图像是单张(第三维是通道数,没有批次维度),则在它的前面添加一个新的批次维度。
- 推理
1with dt[1]:
2 visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
3 pred = model(im, augment=augment, visualize=visualize)
在这一步骤中,首先根据路径 path 和是否要进行可视化 visualize 的标志,确定可视化路径。然后,对预处理后的图像 im 进行推理,获取模型预测结果 pred。可能还应用了数据增强 augment 和可视化。
- 非最大抑制(Non-Maximum Suppression, NMS)
1with dt[2]:
2 pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
这一步骤使用非最大抑制算法对模型的预测结果进行处理,以去除重叠的检测框并保留最佳的检测结果。相关参数包括置信度阈值 conf_thres、交并比阈值 iou_thres、特定类别 classes、是否进行类别不敏感处理 agnostic_nms 以及最大检测数量 max_det。
总结来说,这段代码展示了一个典型的图像处理流程,包括数据准备、模型推理和后处理步骤,同时使用 Profile 对象来测量每个步骤的执行时间,以便于性能分析。
保存结果
在run函数中,保存结果的部分代码负责根据用户的输入参数决定是否以文本或图像的形式保存检测到的边界框,同时也可以选择是否保存检测到目标的裁剪图像。以下是这段代码的详细解释:
# 处理预测结果
for i, det in enumerate(pred): # 遍历每张图像的检测结果
p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
p = Path(p) # 转换为Path对象以便进行路径操作
save_path = str(save_dir / p.name) # 构造保存图像的路径
txt_path = str(save_dir / 'labels' / p.stem) # 构造保存文本的路径
# ... 其他代码 ...
# 如果检测到目标,则进行处理
if len(det):
# 将检测框坐标从模型输出大小映射回原始图像大小
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
# Print results
for c in det[:, 5].unique():
n = (det[:, 5] == c).sum() # detections per class
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
# 遍历检测到的每个目标
for *xyxy, conf, cls in reversed(det):
if save_txt: # 如果设置了保存文本标志
# 将边界框坐标转换为xywh格式,并归一化
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
# 格式化文本行
line = (cls, *xywh, conf) if save_conf else (cls, *xywh)
# 保存检测结果到文本文件
with open(f'{txt_path}.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
if save_img or view_img: # 如果设置了保存图像或显示图像
c = int(cls) # 获取类别编号
# 生成标签文本
label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
# 在原图上绘制边界框和标签
annotator.box_label(xyxy, label, color=colors(c, True))
if save_crop: # 如果设置了保存裁剪图像
# 根据边界框裁剪目标并保存
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
# ... 其他代码 ...
# 如果设置了保存文本或图像结果,打印结果保存位置信息
if save_txt or save_img:
LOGGER.info(f"Results saved to {save_dir}")
让我们分解这个过程中的关键部分:
- 保存文本文件 (save_txt):如果用户通过命令行参数–save-txt指定保存检测结果为文本文件,代码会将每个检测到的目标的信息(类别、坐标和置信度)保存到指定路径下的.txt文件中。如果还指定了–save-conf,则置信度也会被包含在文本文件中。
- 保存图像 (save_img):如果用户设置了保存检测图像的选项,代码会在原始图像上绘制边界框和类别标签,并将带有检测结果的图像保存到指定路径。这通常是通过不设置–nosave选项来实现的。
- 保存裁剪图像 (save_crop):如果用户通过–save-crop指定了需要保存裁剪的检测目标图像,代码会从原始图像中裁剪出检测到的目标,并将它们保存到指定的目录中。
在代码的这部分,通过组合不同的命令行参数,可以灵活地保存检测结果的多种形式,无论是用于进一步分析的文
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
这行代码是在目标检测后处理过程中使用的,其目的是将检测到的边界框的尺寸从模型输入图像的大小(img_size)调整回原始图像的大小(im0 size)。在目标检测任务中,这是一个常见的步骤,因为在推理之前,图像通常会被缩放到模型要求的输入尺寸。为了在原始图像上正确地展示检测结果,需要将检测到的边界框的坐标重新缩放到原始图像的尺寸。
让我们分解这段代码:
det[:, :4]:
- det 是一个二维数组,其中每行代表一个检测结果,包括边界框的坐标、置信度和类别ID。
- [:, :4] 表示选取所有行(所有检测结果)的前四列,这四列分别对应边界框的 x1, y1, x2, y2(左上和右下角坐标)。
scale_boxes(im.shape[2:], det[:, :4], im0.shape):
- 这是一个函数调用,scale_boxes 函数负责将边界框坐标从当前尺寸(模型输入尺寸)调整到另一个尺寸(原始图像尺寸)。
- im.shape[2:] 获取模型输入图像的高度和宽度。
- det[:, :4] 是需要调整尺寸的边界框坐标。
- im0.shape 获取原始图像的尺寸,通常包括通道数、高度和宽度。
.round():
- 这是对 scale_boxes 函数返回的浮点数坐标进行四舍五入的操作,以得到整数像素坐标。
综上所述,这行代码的作用是按比例缩放检测到的边界框坐标,使其与原始图像的尺寸相匹配,并四舍五入结果以得到整数像素坐标,这些坐标随后可用来在原始图像上绘制边界框。
# Print results
for c in det[:, 5].unique():
n = (det[:, 5] == c).sum() # detections per class
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
这段代码的作用是从检测结果中打印每个类别的检测数量。让我们逐步解读这个循环:
- for c in det[:, 5].unique():
- det 是一个二维数组,其中每行代表一个检测结果。
- det[:, 5] 获取 det 数组的第六列,即每个检测结果的类别标识符(因为 Python 索引从0开始)。
- .unique() 方法用来获取第六列中所有不重复的元素,即已检测到的唯一类别标识符。
- n = (det[:, 5] == c).sum()
- 这行代码创建了一个布尔数组,数组中的每个元素对应 det 第六列中的一个元素,用来表示该元素是否等于当前类别 c。
- .sum() 方法对布尔数组中的 True 值(表示为 1)求和,从而得到属于类别 c 的检测数量。
- s += f"{n} {names[int©]}{‘s’ * (n > 1)}, "
- 这行代码构建了一个字符串,用于表示当前类别的检测数量和类别名称。
- names[int©] 从 names 列表中获取类别 c 对应的名称,int© 确保了索引是整数。
- {‘s’ * (n > 1)} 这部分根据检测数量决定是否需要在类别名后面添加 ‘s’ 来表示复数。如果 n 大于 1,则添加 ‘s’。
- 最后,这个构建好的字符串被添加(concatenated)到 s 上,用 , 分隔各个类别的信息。
整体来看,这段代码用于迭代每个唯一的检测到的类别,并统计每个类别的检测数量,然后将每个类别和它的数量以文本形式添加到字符串 s 中,用于后续打印或显示。
# Write results
for *xyxy, conf, cls in reversed(det):
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(f'{txt_path}.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
if save_img or save_crop or view_img: # Add bbox to image
c = int(cls) # integer class
label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
annotator.box_label(xyxy, label, color=colors(c, True))
if save_crop:
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
- 遍历检测结果:
for *xyxy, conf, cls in reversed(det):
- 对 det(检测结果数组,其中每一行包含一个边界框的坐标 xyxy、置信度 conf 和类别 cls)进行逆序迭代。逆序可能是为了在绘制边界框时使某些框在前面显示。
- 保存检测结果到文本文件:
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(f'{txt_path}.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
- 如果 save_txt 标志为 True,则将每个检测结果转换为相应的文本格式并保存到 .txt 文件中。
- xyxy2xywh 函数将边界框从 (x1, y1, x2, y2) 格式转换为 (x_center, y_center, width, height) 格式。
- 归一化坐标是通过除以 gn(归一化增益)来执行的,这个增益是由原始图像的宽度和高度构成的。
- 最后,将格式化的文本行写入到文件中,每个检测结果一行。
- 将边界框的坐标从 (x1, y1, x2, y2) 转换为 (x_center, y_center, width, height) 格式,并进行标准化:
1xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
- xyxy2xywh 函数接受一个边界框的坐标 xyxy 作为输入,并将其转换为 xywh 格式(即中心点坐标和宽高)。
- torch.tensor(xyxy).view(1, 4) 创建一个包含边界框坐标的 PyTorch 张量,并将其重塑为 1x4 的形状。
- 除以 gn (归一化增益,一个包含原始图像宽高的张量)实现对坐标的标准化。标准化是将坐标值转换为相对于原始图像宽度和高度的比例。
- .view(-1).tolist() 将张量展平并转换为 Python 列表。
- 构造要写入文件的行数据:
1line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
- 根据 save_conf 变量的值,确定是否包含置信度 conf 在内。
- 如果 save_conf 为 True,则在行数据中包含类别 cls、标准化的边界框坐标 xywh 和置信度 conf。
- 如果 save_conf 为 False,则只包含类别 cls 和标准化的边界框坐标 xywh。
- 打开文本文件并追加行数据:
1with open(f'{txt_path}.txt', 'a') as f:
2 f.write(('%g ' * len(line)).rstrip() % line + '\n')
- 打开文件 {txt_path}.txt 用于追加写入(模式 ‘a’ 表示 append)。
- 使用格式化字符串 ('%g ’ * len(line)).rstrip() 创建一个格式化模板。其中 %g 是一个占位符,用于表示浮点数或整数,并且会移除无效的小数点和尾随的零。* len(line) 表示重复这个占位符与 line 列表的长度一致次数,最后使用 rstrip() 移除字符串末尾的空格。
- 使用 % 运算符将 line 列表中的值填充到格式化模板中。
- 最后,将这个格式化后的字符串写入文件,并在末尾添加换行符 \n。
综上所述,这段代码将检测框的信息转换为一种格式化的文本表示,包括类别、标准化的边界框坐标和(可选的)置信度,并将其保存到一个文本文件中。这样的文本文件通常用于目标检测任务中的标签或注释文件。
- 在图像上添加边界框和标签:
1if save_img or save_crop or view_img: # Add bbox to image
2 c = int(cls) # integer class
3 label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
4 annotator.box_label(xyxy, label, color=colors(c, True))
- 如果设置了 save_img、save_crop 或 view_img 中的任一标志,则执行图像标注。
- label 变量根据 hide_labels 和 hide_conf 标志来决定是否应该显示标签和置信度。
- 使用 annotator 对象(一个用于图像标注的类实例)将边界框和标签添加到图像上。
- 保存裁剪的检测框图像:
1if save_crop:
2 save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
- 如果 save_crop 标志为 True,则单独保存每个检测到的对象的图像裁剪部分。
- save_one_box 函数从图像 imc 中裁剪出边界框 xyxy 指定的区域,并将裁剪的图像保存到指定的路径。
总结来说,这段代码在处理完检测结果后,根据配置的选项将检测结果保存到文本文件、在原始图像上绘制边界框和标签,并且可选地保存每个检测对象的图像裁剪。
# Stream results
im0 = annotator.result()
if view_img:
if platform.system() == 'Linux' and p not in windows:
windows.append(p)
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
cv2.imshow(str(p), im0)
cv2.waitKey(1) # 1 millisecond
这段代码的作用是在屏幕上实时显示目标检测的结果。它使用了 cv2(OpenCV库)函数来创建窗口和显示图像。下面是详细的步骤解读:
- 获取标注后的图像:
- 如果 view_img 标志设置为 True,则显示图像:
- 对于 Linux 系统,创建新的显示窗口:
- 在窗口中显示图像:
- 等待 1 毫秒以刷新显示:
1im0 = annotator.result()
- annotator 是一个用于在图像上添加注释的对象(例如,在目标周围画边界框、添加类别标签等)。
- 调用该对象的 result() 方法以获取标注完成后的图像,这个图像已经包含了所有的检测结果的可视化信息。
1if view_img:
- 仅在需要显示图像时(例如,进行实时监控或检查模型的性能时),view_img 变量会被设置为 True。
1if platform.system() == 'Linux' and p not in windows:
2 windows.append(p)
3 cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
4 cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
- 检查当前运行平台是否为 Linux,并且路径 p(通常指图像或视频帧的路径)不在已创建的窗口列表 windows 中。
- 如果是 Linux 系统,并且窗口尚未创建,则创建一个新的命名窗口,允许用户调整窗口大小,并保持图像的长宽比。
- 将窗口的大小调整为图像 im0 的大小。
1cv2.imshow(str(p), im0)
- 使用 cv2.imshow() 函数显示图像。第一个参数是窗口名称,这里使用图像路径 p 转换为字符串作为窗口名称,第二个参数是要显示的图像。
1cv2.waitKey(1) # 1 millisecond
- 使用 cv2.waitKey() 函数等待一个很短的时间(1毫秒),这允许图像在窗口中更新,并且可以让用户通过按键来交互(例如,按键退出显示)。
总的来说,这段代码是用来将处理完毕的图像实时显示到屏幕上,以便用户可以实时查看目标检测模型的输出结果。在 Linux 系统上,代码还处理了窗口的创建和尺寸调整,以方便用户观察。
# Save results (image with detections)
if save_img:
if dataset.mode == 'image':
cv2.imwrite(save_path, im0)
else: # 'video' or 'stream'
if vid_path[i] != save_path: # new video
vid_path[i] = save_path
if isinstance(vid_writer[i], cv2.VideoWriter):
vid_writer[i].release() # release previous video writer
if vid_cap: # video
fps = vid_cap.get(cv2.CAP_PROP_FPS)
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 30, im0.shape[1], im0.shape[0]
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer[i].write(im0)
这段代码的目的是保存经过检测模型处理并添加了检测框的图像或视频帧。它根据输入数据是单幅图像还是视频流进行不同的操作:
- 检查是否有保存图像的标志:
1if save_img:
- 如果 save_img 设置为 True,则执行保存操作。
- 保存单幅图像:
1if dataset.mode == 'image':
2 cv2.imwrite(save_path, im0)
- 如果数据集模式设置为 ‘image’,则表示正在处理单个图像。使用 cv2.imwrite 函数将处理后的图像 im0 保存到指定的 save_path 路径。
- 保存视频帧:
1else: # 'video' or 'stream'
2 if vid_path[i] != save_path: # new video
3 ...
- 如果数据集模式不是 ‘image’,则默认处理的是视频或视频流。接下来的操作是为了将检测帧保存为视频文件。
- 处理视频文件保存路径:
1if vid_path[i] != save_path:
2 vid_path[i] = save_path
3 ...
- 检查 vid_path[i] 是否与 save_path 相同。如果它们不同,意味着开始处理新视频,需要进行一系列初始化操作。
- 释放上一个 VideoWriter 对象:
1if isinstance(vid_writer[i], cv2.VideoWriter):
2 vid_writer[i].release()
- 检查 vid_writer[i] 是否是 cv2.VideoWriter 的实例,如果是,则释放它。这通常在开始写入新的视频文件之前进行。
- 设置视频写入参数:
1if vid_cap:
2 fps = vid_cap.get(cv2.CAP_PROP_FPS)
3 w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
4 h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
5else:
6 fps, w, h = 30, im0.shape[1], im0.shape[0]
- 获取视频帧的属性,如帧率 fps,宽度 w 和高度 h。如果不是从视频捕获设备读取 (vid_cap 为空),则设置默认值。
- 创建 VideoWriter 对象:
1save_path = str(Path(save_path).with_suffix('.mp4'))
2vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
- 将保存路径的文件扩展名强制设置为 .mp4。
- 创建 cv2.VideoWriter 对象以写入视频文件,使用 ‘mp4v’ 编解码器,以及之前获取的帧率和帧尺寸。
- 写入视频帧:
1vid_writer[i].write(im0)
- 将处理后的帧 im0 写入视频文件。
总体来说,这段代码检查是否需要保存结果图像或视频,对于图像直接保存文件,对于视频则创建或更新 VideoWriter 对象,并将处理后的帧写入视频文件。这样可以得到一个包含检测结果的视频,用于后续的查看或分析。
LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
- LOGGER.info(…)
- LOGGER 是一个配置好的日志记录器对象,用于向控制台、文件或其他日志处理设施输出日志消息。
- info 方法用于记录一条信息级别的日志,表示重要的事件,比如程序的正常操作。
- f"{s}{‘’ if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms"
- 这是一个格式化字符串(f-string),用于构建需要记录的消息文本。
- {s} 会被替换为之前构建的字符串 s,它包含了关于每个类别的检测数量的信息。
- {‘’ if len(det) else '(no detections), '} 用于确定是否有检测到对象。如果 det(即检测结果列表)的长度为非零,那么将不添加额外的文本;如果长度为零,意味着没有检测到任何对象,将添加文本 (no detections), 。
- {dt[1].dt * 1E3:.1f}ms 用于显示处理一帧所需的时间(以毫秒为单位)。dt[1].dt 代表检测耗时的某个时间差对象,* 1E3 是将其从秒转换成毫秒,:.1f 表示格式化为一位小数的浮点数。
整个日志消息构建的目的是提供一条包括检测数量、是否有检测到对象以及检测耗时的综合信息,以便于理解模型在当前帧上的表现和速度。
性能评测
# Print results
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
if save_txt or save_img:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
if update:
strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)