一、YOLO部分代码
这是我自己封装的一个用于图片检测的 detect.py：
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Run inference on images, videos, directories, streams, etc.
Usage:
$ python path/to/detect_light.py --source path/to/img.jpg --weights yolov5s.pt --img 640
"""
import argparse
import os
import sys
from pathlib import Path
import math
import cv2
import numpy as np
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch import device
import time
import datetime
import os.path
import random
import torch
from flask import Flask, render_template, request
from PIL import Image
app = Flask(__name__)  # Flask application object; the /detect route below is registered on it

# This script's directory doubles as the YOLOv5 root: the `models` / `utils` imports below
# are expected to be found there.
FILE = Path(__file__).resolve()
ROOT = FILE.parent
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # make the local YOLOv5 packages importable
# From here on ROOT is expressed relative to the current working directory.
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
from models.experimental import attempt_load
from utils.datasets import LoadImages, LoadStreams
from utils.general import apply_classifier, check_img_size, check_imshow, check_requirements, check_suffix, colorstr, \
increment_path, non_max_suppression, print_args, save_one_box, scale_coords, set_logging, \
strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors
from utils.torch_utils import load_classifier, select_device, time_sync
class Main:
    """YOLOv5 image-detection pipeline configured by a plain ``config`` dict.

    Keys read from ``config``: ``weights``, ``source``, ``imgsz``, ``conf_thres`` and
    ``iou_thres`` (read without a default), plus the optional ``agnostic_nms``, ``classes``,
    ``save_img``, ``view_img``, ``save_crop``, ``save_txt``, ``save_conf``, ``if_save_dir``,
    ``hide_conf``, ``hide_labels``, ``visualize`` and ``device`` (default ``"cpu"``).

    Annotated images are written to ``ROOT / "static/images"``; see ``_result_name`` for how
    output file names are derived from input names.
    """

    def __init__(self, config, device="cuda"):
        # NOTE(review): the ``device`` parameter is ignored; the device actually used comes from
        # config["device"] (default "cpu"). It is kept so existing callers keep working.
        self.start_time = time.time()  # process_result reports detection times relative to this
        self.config = config
        self.weights = self.config.get("weights")
        self.source = self.config.get("source")
        self.imgsz = self.config.get("imgsz")
        self.conf_thres = self.config.get("conf_thres")
        self.iou_thres = self.config.get("iou_thres")
        self.agnostic_nms = self.config.get("agnostic_nms", False)
        self.classes = self.config.get("classes", None)
        self.save_img = self.config.get("save_img", True)
        self.view_img = self.config.get("view_img", False)
        self.save_crop = self.config.get("save_crop", True)
        self.save_txt = self.config.get("save_txt", False)
        self.save_conf = self.config.get("save_conf", False)
        self.if_save_dir = self.config.get("if_save_dir", True)
        self.hide_conf = self.config.get("hide_conf", False)
        self.hide_labels = self.config.get("hide_labels", False)
        self.visualize = self.config.get("visualize", False)
        self.save_dir = ROOT / 'static/images'
        self.gen_save_dir()
        self.device = select_device(self.config.get("device", "cpu"))
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA
        self.w = str(self.weights[0] if isinstance(self.weights, list) else self.weights)
        if 'torchscript' in self.w:
            self.model = torch.jit.load(self.w)
        else:
            self.model = attempt_load(self.weights, map_location=self.device)
        if self.half:
            self.model.half()
        self.dataset = LoadImages(self.source, img_size=self.imgsz)
        # Class names; a wrapped model (e.g. DataParallel) exposes them on `.module`.
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        # One random color per class (process_result draws with utils.plots.colors instead).
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(self.names))]

    def gen_save_dir(self):
        """Resolve ``self.save_dir`` through ``increment_path`` and ensure the directory exists.

        ``exist_ok=True`` presumably reuses an existing directory instead of incrementing it.
        ``mkdir=True`` creates it when missing: ``cv2.imwrite`` does not create parent
        directories and just returns False, so results would otherwise be lost silently.
        """
        self.save_dir = increment_path(Path(self.save_dir), exist_ok=True, mkdir=True)

    @staticmethod
    def _result_name(src_path):
        """Return the file name under which the annotated copy of ``src_path`` is saved.

        The upload handler names inputs ``<id>_o<ext>``; the result is ``<id>_p<ext>``. Only a
        trailing ``_o`` on the stem is swapped (replacing every "o" in the name mangled names
        such as ``photo_o.png``). Any other stem gets ``_p`` appended, so the result never
        overwrites its own input when input and output share a directory.
        """
        stem = src_path.stem
        if stem.endswith("_o"):
            stem = stem[:-2]
        return f"{stem}_p{src_path.suffix}"

    def do_inference(self, img, path):
        """Convert one image to a tensor, run the model and apply NMS.

        Args:
            img: image array produced by LoadImages with 0-255 values (assumed already
                letterboxed and channel-first - TODO confirm against utils.datasets).
            path: source file path; only used to name the feature-map folder when
                ``visualize`` is enabled.

        Returns:
            ``(pred, img)``: the per-image detections after NMS, and the batched float tensor
            that was fed to the model (its spatial size is needed to rescale boxes).
        """
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half else img.float()  # uint8 to fp16/32
        img = img / 255.0  # 0 - 255 to 0.0 - 1.0
        if len(img.shape) == 3:
            img = img[None]  # add batch dimension
        visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.visualize else False
        pred = self.model(img, augment=False, visualize=visualize)[0]
        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes,
                                   self.agnostic_nms, max_det=1000)
        return pred, img

    def process_result(self, img, im0s, path):
        """Detect objects in one image, draw/save the results and return box measurements.

        Args:
            img: preprocessed image array from LoadImages.
            im0s: the original image; boxes are drawn on a copy of it.
            path: source file path; decides the output name via ``_result_name``.

        Returns:
            A list with one ``[time_use, box_width, box_height]`` entry per detection:
            seconds since this ``Main`` was created (2 decimals) and the box size in
            original-image pixels (2 decimals).
        """
        pred, img = self.do_inference(img, path)
        res_list = []
        p = Path(path)
        save_path = str(self.save_dir / self._result_name(p))
        for det in pred:  # one entry per image in the batch
            im0 = im0s.copy()
            imc = im0.copy() if self.save_crop else im0  # undrawn copy for save_crop
            annotator = Annotator(im0, line_width=3, example=str(self.names))
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # w, h, w, h normalization gain
            if len(det):
                # Rescale boxes from the model input size to the original image size.
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, cls in reversed(det):
                    time_use = round(time.time() - self.start_time, 2)
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    box_width = round(xywh[2] * im0.shape[1], 2)
                    box_height = round(xywh[3] * im0.shape[0], 2)
                    res_list.append([time_use, box_width, box_height])
                    if self.save_img or self.save_crop or self.view_img:  # add bbox to image
                        c = int(cls)
                        label = None if self.hide_labels else (
                            self.names[c] if self.hide_conf else f'{self.names[c]} {conf:.2f}')
                        annotator.box_label(xyxy, label, color=colors(c, True))
                        if self.save_crop:
                            save_one_box(xyxy, imc, file=self.save_dir / 'crops' / self.names[c] / f'{p.stem}.jpg',
                                         BGR=True)
            # Upstream YOLOv5's Annotator can draw on its own PIL copy (e.g. for non-ASCII class
            # names), so take the drawn image from result() instead of assuming im0 changed in place.
            im0 = annotator.result()
            if self.save_img and self.dataset.mode == 'image':
                cv2.imwrite(save_path, im0)
        return res_list

    @torch.no_grad()
    def run(self):
        """Run detection over every image of the dataset.

        Returns:
            A list with the ``process_result`` output of every image that was processed
            without an exception (failures are printed and skipped).
        """
        if self.device.type != 'cpu':
            # Warm-up pass on a dummy batch (the old check read `.type` off the torch.device
            # class instead of self.device, so it was not a real device check).
            dummy = torch.zeros((1, 3, self.imgsz, self.imgsz), device=self.device)
            self.model(dummy.half() if self.half else dummy)
        reslist = []
        for path, img, im0s, vid_cap in self.dataset:
            try:
                reslist.append(self.process_result(img, im0s, path))
            except Exception as e:
                # Deliberate best effort: report the failing image and continue with the rest.
                print("报错:", e)
        return reslist
可以看到，参数是以字典（类似 JSON）的形式传进来的。整个流程基本上只有几步：
gen_save_dir 生成保存文件夹；do_inference 对图片做预处理，并完成推理和 NMS；process_result 处理检测框和标签，并保存结果图片。
二、HTML部分代码
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>yolo检测</title>
    <style>
        div {
            margin: 0 auto;
            text-align: center;
        }
        /* Animated rainbow heading: gradient clipped to the text, hue rotated continuously. */
        .flow {
            height: 50px;
            text-align: center;
            background: linear-gradient(to right, red, orange, yellow, green, cyan, blue, purple);
            -webkit-background-clip: text;
            background-clip: text;
            -webkit-text-fill-color: transparent;
            animation: hue 3s linear infinite;
            font-size: 50px;
            padding: 20px;
        }
        @keyframes hue {
            0% {
                filter: hue-rotate(0deg);
            }
            100% {
                filter: hue-rotate(360deg);
            }
        }
        input, p {
            font-size: 20px;
        }
    </style>
</head>
<body>
<h1 class="flow">上传图片</h1>
<!-- No action attribute: the form posts back to the current URL; the upload field is "file". -->
<form method="post" enctype="multipart/form-data">
    <div style="margin-bottom:20px;">
        <input type="file" name="file" accept="image/*" required>
        <input type="submit" value="提交">
    </div>
</form>
<div>
    {# Images are only rendered when the view passes their paths (after a POST). The src values
       are quoted so an empty value cannot swallow the following attribute. #}
    {% if o_img_path %}
    <div class="img_box" style="display: inline-block">
        <p style="color:blue">待检测图片</p>
        <img src="{{ o_img_path }}" width="300"/>
    </div>
    {% endif %}
    {% if p_img_path %}
    <div class="img_box" style="display: inline-block">
        <p style="color:green">检测图片</p>
        <img src="{{ p_img_path }}" width="300"/>
    </div>
    {% endif %}
</div>
</body>
</html>
模板中的 o_img_path 和 p_img_path 由 Flask 通过 render_template 传入，名字要和模板里的变量一一对应。
三、Flask部分代码
@app.route('/detect', methods=["POST", "GET"])
def detect():
    """Serve the upload page and run YOLO detection on an uploaded image.

    A GET, or a POST without a usable file, renders ``index.html`` without images. A POST
    with a file saves it as ``static/images/<id>_o<ext>`` next to this script and runs
    ``Main`` on that single file, which is expected to write ``static/images/<id>_p<ext>``.
    The view then renders ``index.html`` with both image URLs. Every other file in
    ``static/images`` is deleted so only the latest pair is kept.
    """
    if request.method != "POST":
        return render_template('index.html')

    # 1. Read the upload. Path(...).suffix keeps only the last extension of the base name
    # ("a.b.JPG" -> ".JPG"), so dotted names work and client-supplied directories are ignored.
    f = request.files.get('file')
    suffix = Path(f.filename).suffix.lower() if f is not None and f.filename else ""
    if not suffix:
        # No file chosen, or a name without an extension: just show the form again.
        return render_template('index.html')

    # 2. Unique base name: timestamp down to microseconds + a zero-padded two-digit random number.
    now_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    unique_num = f"{now_time}{random.randint(0, 99):02d}"
    o_name = f"{unique_num}_o{suffix}"
    p_name = f"{unique_num}_p{suffix}"  # name Main gives the annotated result

    # Filesystem directory (next to this script, like ROOT) vs. the URL prefix used in <img src>.
    image_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "static", "images")
    base_url = "./static/images/"
    os.makedirs(image_dir, exist_ok=True)

    # 3. Save the original image.
    o_file = os.path.join(image_dir, o_name)
    f.save(o_file)

    config = {
        "weights": ROOT / "best.pt",
        "source": o_file,  # only the new upload, not every image left in the folder
        "imgsz": 640,
        "conf_thres": 0.45,
        "iou_thres": 0.4,
        "classes": [0, 1],  # class ids to detect, in the order of the dataset yaml
        "save_img": True,
        "view_img": False,
        "if_save_dir": True,
        "save_crop": False,
        "save_txt": False,
        "save_conf": False,
        "hide_conf": False,
        "hide_labels": False,
        "visualize": False,
        "agnostic_nms": False
    }
    # 4. Run detection.
    # NOTE(review): the model is reloaded on every request; build it once if latency matters.
    main = Main(config)
    main.run()

    # 5. Remove earlier uploads/results (files only; subfolders such as crops/ are left alone).
    # NOTE(review): not safe with concurrent requests - one request can delete another's files.
    keep = {o_name, p_name}
    for entry in os.listdir(image_dir):
        entry_path = os.path.join(image_dir, entry)
        if entry not in keep and os.path.isfile(entry_path):
            os.remove(entry_path)

    return render_template('index.html', o_img_path=base_url + o_name, p_img_path=base_url + p_name)
if __name__ == "__main__":
    # Script entry point: start Flask's built-in development server with default arguments.
    app.run()
这段 Flask 代码我是直接写在 YOLO 代码的后面（同一个文件里）的，YOLO 的参数也是从 Flask 传过去的；其他参数也可以连接到前端，不用全部写死。
四、文件路径
运行文件名：detect_locate_flask.py；static/images 是存放上传图片以及 YOLO 推理后图片的路径；templates 是前端页面（index.html）所在的路径。
运行结果展示:
主要是参考这位博主的内容,自己走了一遍: