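The two most widely used deep learning frameworks in AI today are TensorFlow and PyTorch.
This article covers the YOLOv5 object detection algorithm, which is built on PyTorch, and mainly walks through porting and debugging the code.
The development environment is Ubuntu (Linux); the steps on Windows are similar. Requirements: Python 3.6+, PyTorch 1.5+.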
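Before going further, it can help to confirm that the interpreter and PyTorch meet these minimums. The following is a minimal sketch, not part of YOLOv5: the helper names `version_tuple` and `meets_minimum` are made up for illustration, and the parsing assumes a version string that starts with `major.minor`.

```python
import re
import sys


def version_tuple(version: str) -> tuple:
    """Extract (major, minor) from a version string such as '1.7.0+cu101'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognized version string: " + version)
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Return True if the version is at least the (major, minor) minimum."""
    return version_tuple(version) >= minimum


if __name__ == "__main__":
    print("Python OK:", sys.version_info[:2] >= (3, 6))
    try:
        import torch  # only needed for the PyTorch check
        print("PyTorch OK:", meets_minimum(torch.__version__, (1, 5)))
    except ImportError:
        print("PyTorch is not installed")
```

Comparing tuples rather than raw strings avoids the trap where "1.10" sorts before "1.5" as text.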
Basic preparation
First, get the latest YOLOv5 source code from the GitHub repository:
git clone https://github.com/ultralytics/yolov5.git
Enter the project root directory:
cd yolov5
The top-level directory structure of the project is as follows:
yolov5/
├── data
├── detect.py
├── Dockerfile
├── hubconf.py
├── LICENSE
├── models
├── README.md
├── requirements.txt # dependencies that need to be installed
├── test.py
├── train.py
├── tutorial.ipynb
├── utils
└── weights
1. Environment setup
Install the dependencies (this step can be skipped if the required libraries are already installed):
sudo python3 -m pip install -r requirements.txt
Edit the model download script:
sudo vi weights/download_weights.sh
Change python to python3. After the change, the script reads:
python3 - <<EOF
from utils.google_utils import attempt_download
for x in ['s', 'm', 'l', 'x']:
    attempt_download(f'yolov5{x}.pt')
EOF
Run the model download script from the project root:
sudo bash weights/download_weights.sh
Download output:
Downloading https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5s.pt to yolov5s.pt...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 14.1M/14.1M [00:01<00:00, 10.3MB/s]
Downloading https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5m.pt to yolov5m.pt...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 41.1M/41.1M [00:40<00:00, 1.06MB/s]
Downloading https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5l.pt to yolov5l.pt...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 90.2M/90.2M [02:01<00:00, 775kB/s]
Downloading https://github.com/ultralytics/yolov5/releases/download/v5.0/yolov5x.pt to yolov5x.pt...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 168M/168M [03:21<00:00, 876kB/s]
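Once the download finishes, four models of different sizes appear in the project root. Their complexity increases from s to x, and so does the GPU floating-point compute they require: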
yolov5s.pt 15M
yolov5m.pt 42M
yolov5l.pt 92M
yolov5x.pt 171M
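To confirm that all four checkpoints arrived, a small check like the one below can be run from the project root. This is only a sketch: `missing_weights` is a hypothetical helper, and it tests only that each file exists and is non-empty, not that its contents are valid.

```python
from pathlib import Path

# File names as produced by the download script above
WEIGHT_FILES = ["yolov5s.pt", "yolov5m.pt", "yolov5l.pt", "yolov5x.pt"]


def missing_weights(root="."):
    """Return the expected weight files that are absent or empty under root."""
    root = Path(root)
    missing = []
    for name in WEIGHT_FILES:
        path = root / name
        if not path.is_file() or path.stat().st_size == 0:
            missing.append(name)
    return missing


if __name__ == "__main__":
    print("missing:", missing_weights() or "none")
```

If the list is not empty, rerunning the download script is usually the simplest fix.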
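Comparing the YOLOv5 models of different complexity in the official chart shows that:
YOLOv5x is the most accurate of the four, but inference takes the longest,
YOLOv5s is the least accurate, but it is the fastest.
In a production environment, choose the model according to your machine configuration and project requirements.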
Using the official inference script:
Image inference:
sudo python3 detect.py --source /image/path --weights yolov5s.pt
The test result is as follows:
Video inference:
sudo python3 detect.py --source /video/path --weights yolov5s.pt
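When the run finishes, if any objects were detected, the annotated image or video is saved to inference/output.
2. Custom test script
Next, in the project root, we write our own script my_test.py based on the detection script detect.py: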
vi my_test.py
0x01. Import the required packages:
import cv2
import torch
import numpy as np
from utils.datasets import letterbox
from models.experimental import attempt_load
from utils.general import (
    check_img_size, non_max_suppression, apply_classifier, scale_coords,
    xyxy2xywh, plot_one_box, strip_optimizer, set_logging)
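0x02. Initialize the basic parameters: GPU or CPU selection, the network model, the non-maximum suppression (IoU) threshold, and the confidence threshold: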
imgsz = 640
video_path = "/video/path"
weights = "yolov5x.pt"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
augment = False  # test-time augmentation off
BAGGAGE_TAGS = ['backpack', 'handbag', 'umbrella', 'cell phone', 'bicycle', 'suitcase', 'car']  # names must match entries in classes
# conf_thres = 0.25
# iou_thres = 0.45
conf_thres = 0.150
iou_thres = 0.071
class_path = "data/coco.names"
classes = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
           'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
           'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
           'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
           'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
           'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife',
           'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
           'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
           'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
           'clock', 'vase', 'scissors', 'teddy bear', 'hair dryer', 'toothbrush']
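agnostic_nms = False  # class-agnostic NMS off
model = attempt_load(weights, map_location=device)
if device.type != 'cpu':
    model.half()  # use fp16 only on the GPU; the CPU path stays in fp32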
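The two thresholds set above control which boxes survive: `conf_thres` drops low-confidence detections, and `iou_thres` decides how much two boxes may overlap before non-maximum suppression discards the lower-scoring one. As a rough illustration of the overlap measure, here is a minimal pure-Python sketch of IoU for two boxes in (x1, y1, x2, y2) form. The name `iou_xyxy` is ours, chosen for illustration; this is not the implementation YOLOv5 uses internally.

```python
def iou_xyxy(box_a, box_b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    # Corners of the overlapping rectangle (may be empty)
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


if __name__ == "__main__":
    a, b = (0, 0, 10, 10), (5, 0, 15, 10)
    print(iou_xyxy(a, b))  # intersection 50, union 150
```

With the `iou_thres = 0.071` used in this script, two boxes of the same class that overlap as much as `a` and `b` (IoU of about 0.33) count as duplicates, and only the higher-scoring one is kept.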
0x03. Core detection function, including drawing the label text and bounding boxes:
def detect_one_pic(cv_img: np.ndarray):
    """Detect objects in one BGR image; return the annotated image and the list of class names found."""
    # Resize and pad (letterbox) to the network input size
    img = letterbox(cv_img, new_shape=imgsz)[0]
    # Convert
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
    img = np.ascontiguousarray(img)
    img = torch.from_numpy(img).to(device)
    img = img.to(next(model.parameters()).dtype)  # uint8 to the model's dtype (fp16/32)
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)
    with torch.no_grad():  # inference only, no gradients needed
        pred = model(img, augment=bool(augment))[0]
    # Apply NMS
    detections = non_max_suppression(pred, conf_thres, iou_thres, classes=None, agnostic=bool(agnostic_nms))[0]
    tag_list = []
    if isinstance(detections, torch.Tensor) and len(detections):
        # Map the boxes from the letterboxed input back to the original image size;
        # this replaces the hard-coded ratios that only approximated 1280x960 and 1920x1080 frames
        detections[:, :4] = scale_coords(img.shape[2:], detections[:, :4], cv_img.shape).round()
        for x1, y1, x2, y2, conf, cls_pred in detections:
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            tag = classes[int(cls_pred)]
            tag_list.append(tag)
            if tag in BAGGAGE_TAGS:
                color = (0, 0, 255)  # red box for baggage-type objects
            elif tag == 'person':
                color = (0, 255, 0)  # green box for people
            else:
                continue  # other classes are recorded in tag_list but not drawn
            label = "{} {}".format(tag, round(float(conf) * 100, 5))  # class name and confidence in percent
            cv2.rectangle(cv_img, (x1, y1), (x2, y2), color, 2)
            cv2.putText(cv_img, label, (x1 + 5, y1 + 12), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    # Always return the frame so the caller can display it even when nothing was detected
    return cv_img, tag_list
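Detections come back in the coordinates of the letterboxed network input, so they have to be mapped to the original frame before drawing. The general mapping can be sketched as follows, assuming the image was scaled by a single gain and padded equally on both sides, which is what letterboxing does. `unletterbox_box` is a hypothetical name used only for this illustration.

```python
def unletterbox_box(box, input_hw, orig_hw):
    """Map an (x1, y1, x2, y2) box from letterboxed input coordinates to the original image."""
    in_h, in_w = input_hw
    orig_h, orig_w = orig_hw
    gain = min(in_h / orig_h, in_w / orig_w)  # single scale factor applied to the image
    pad_w = (in_w - orig_w * gain) / 2  # horizontal padding on each side
    pad_h = (in_h - orig_h * gain) / 2  # vertical padding on each side
    x1, y1, x2, y2 = box
    mapped = ((x1 - pad_w) / gain, (y1 - pad_h) / gain,
              (x2 - pad_w) / gain, (y2 - pad_h) / gain)
    # Clip to the image bounds and round to whole pixels
    limits = (orig_w, orig_h, orig_w, orig_h)
    return tuple(int(round(min(max(v, 0), lim))) for v, lim in zip(mapped, limits))
```

For a 1280x960 frame letterboxed into a 640x640 input, the gain is 0.5 and the vertical padding is 80 pixels on each side, so the input box (0, 80, 640, 560) maps back to the full frame (0, 0, 1280, 960).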
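Image detection function:
def test_pic():
    img = cv2.imread('/image/path')
    if img is None:  # imread returns None when the file cannot be read
        raise FileNotFoundError('could not read /image/path')
    cv_img, tags = detect_one_pic(img)
    cv2.imshow('', cv_img)
    cv2.waitKey(0)  # wait for any key press
    cv2.destroyAllWindows()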
Real-time video stream detection function:
def test_video():
    cap = cv2.VideoCapture(video_path)
    while True:
        ret, frame = cap.read()
        if not ret:
            break  # end of the stream or a read error
        cv_img, tags = detect_one_pic(frame)
        cv2.imshow('', cv_img)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break  # quit when q is pressed
    cap.release()
    cv2.destroyAllWindows()
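The `tags` list returned for each frame is not used by the functions above. One possible use, sketched here on the assumption that we only want a running tally per class, is to accumulate the names with `collections.Counter`. The helper `count_tags` is hypothetical and not part of the original script.

```python
from collections import Counter


def count_tags(per_frame_tags):
    """Tally how often each class name appears across a sequence of per-frame tag lists."""
    totals = Counter()
    for tags in per_frame_tags:
        totals.update(tags)
    return totals


if __name__ == "__main__":
    frames = [["person", "backpack"], [], ["person", "car", "person"]]
    print(count_tags(frames).most_common())
```

Inside the video loop, the same idea would amount to calling `totals.update(tags)` once per frame.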
Entry point:
if __name__ == '__main__':
    # test_pic()
    test_video()
0x04. Test the custom script:
sudo python3 my_test.py
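The results are displayed in an OpenCV window; press q to exit.
To guard against future upstream changes breaking this code, I forked the repository to my own account and will keep it updated if anything changes: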
https://github.com/geniustesda/yolov5
Original article:
http://blog.tesda.info/index.php/archives/3/