引言
yolov5官方的detect.py文件集成度比较高。一般情况下我们不需要同时实现这么多的功能,比如:检测图像、视频、摄像头等,包括支持pytorch、tensorflow等模型框架。
本节内容是基于yolov5-6.0的detect.py文件改写的。我相信也适用于相近的yolov5版本。
本节代码功能支撑:windows系统、pytorch模型框架、检测视频文件、基于python语言
使用方法:直接用本节代码替换detect.py文件或者在官方detect.py同级目录下新建一个py文件,配置好环境后直接运行。
代码文件.py
import cv2
import numpy as np
import torch
import os
from models.experimental import attempt_load
from utils.general import check_img_size, non_max_suppression, scale_coords
from utils.plots import Annotator, colors
from utils.augmentations import letterbox
from utils.torch_utils import select_device
# 设置相关参数和路径
img_size = 640
stride = 32
weights = 'yolov5s.pt' # 模型权重文件路径
device = '0' # 设置设备类型
source = 'data/images/3.mp4' # 输入图像路径(也可以是绝对路径)
save_path = 'run1/' # 输出图像保存路径(也可以是绝对路径)
view_img = True # 是否显示检测结果的图像窗口
half = False
# 选择设备(GPU 或 CPU)
device = select_device(device)
# 判断是否可用半精度浮点数运算
half &= device.type != 'cpu'
# 导入模型
model = attempt_load(weights, map_location=device)
# 检查并调整图像大小
img_size = check_img_size(img_size, s=stride)
# 获取模型标签名称
names = model.names
# 读取视频对象
cap = cv2.VideoCapture(source)
frame = 0 # 开始处理的帧数
frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 待处理的总帧数
# 获取当前视频的帧率与宽高,设置同样的格式,以确保相同帧率与宽高的视频输出
fps = cap.get(cv2.CAP_PROP_FPS)
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# 设置视频保存路径
save_path += os.path.basename(source)
# 创建视频编码器
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
# 循环处理视频帧
while frame <= frames:
# 读取帧图像
ret_val, img0 = cap.read()
if not ret_val:
break
frame += 1
print(f'video {frame}/{frames} {save_path}')
# 调整图像大小
img = letterbox(img0, img_size, stride=stride, auto=True)[0]
# 转换为模型所需的格式
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
img = np.ascontiguousarray(img)
img = torch.from_numpy(img).to(device)
img = img.float() / 255.0 # 0 - 255 to 0.0 - 1.0
img = img[None] # [h w c] -> [1 h w c]
# 进行目标检测
pred = model(img)[0]
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=1000)
# 绘制边界框和标签
det = pred[0]
annotator = Annotator(img0.copy(), line_width=3, example=str(names))
if len(det):
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()
for *xyxy, conf, cls in reversed(det):
c = int(cls) # integer class
label = f'{names[c]} {conf:.2f}'
annotator.box_label(xyxy, label, color=colors(c, True))
# 将带有边界框的图像写入视频
im0 = annotator.result()
vid_writer.write(im0)
if view_img:
# 调整图像大小并显示
im0 = cv2.resize(im0, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
cv2.imshow(str('image'), im0)
cv2.waitKey(1)
# 释放资源
vid_writer.release()
cap.release()
print(f'{source} finish, save to {save_path}')
结语
大家一起努力,一步一个脚印,加油!!!