使用 OpenCV 输出 YOLOv5 的预测结果，并简化推理代码，以便于部署。
需要用到 YOLOv5 源代码中的以下模块：models.experimental、utils.general 和 utils.torch_utils（本代码基于较早的 YOLOv5 版本；新版本中 scale_coords 已更名为 scale_boxes）。
代码如下，只需修改权重文件和输入图片的路径即可：
import cv2
import numpy as np
import torch
from models.experimental import attempt_load
from utils.general import non_max_suppression, scale_coords
from utils.torch_utils import select_device
# ---- Configuration: change these paths for your own weights / image ----
weights = 'yolov5s.pt'  # .pt checkpoint, or a TorchScript file whose name contains 'torchscript'
source = 'data/images/bus.jpg'  # input image (named `source` so it does not shadow the builtin `input`)
img_size = 640  # side length of the square network input
CONF_THRES = 0.25  # NMS confidence threshold
IOU_THRES = 0.45  # NMS IoU threshold


def _letterbox(image, new_size=640, pad_value=114):
    """Resize *image* keeping its aspect ratio, then pad it to a new_size x new_size square.

    scale_coords() maps boxes back to the original image assuming exactly this
    uniform-scale + centered-padding transform. A plain cv2.resize to a square
    stretches the image, and the boxes mapped back by scale_coords would then
    be misplaced on any non-square input.

    Args:
        image: HWC uint8 image as returned by cv2.imread.
        new_size: side length of the square output.
        pad_value: gray level used for the padding border.

    Returns:
        The letterboxed HWC image of shape (new_size, new_size, C).
    """
    h0, w0 = image.shape[:2]
    ratio = min(new_size / h0, new_size / w0)
    new_w, new_h = int(round(w0 * ratio)), int(round(h0 * ratio))
    if (new_w, new_h) != (w0, h0):
        # cv2.resize takes dsize as (width, height), not (height, width).
        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    dw, dh = (new_size - new_w) / 2, (new_size - new_h) / 2
    # dw/dh are whole or half pixels; the +-0.1 splits an odd remainder so the
    # two sides always add up to exactly new_size.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    return cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT,
                              value=(pad_value, pad_value, pad_value))


# '' lets select_device pick the first CUDA GPU when available and fall back to
# the CPU otherwise; select_device(0) fails on machines without CUDA.
device = select_device('')
w = str(weights[0] if isinstance(weights, list) else weights)
# map_location puts the TorchScript model on the same device as the input tensor.
model = torch.jit.load(w, map_location=device) if 'torchscript' in w else attempt_load(weights, map_location=device)

img0 = cv2.imread(source)  # original BGR image, kept for drawing the results
if img0 is None:
    raise FileNotFoundError(f'Could not read image: {source}')
img = _letterbox(img0, img_size)
img = np.ascontiguousarray(img[:, :, ::-1].transpose((2, 0, 1)))  # BGR -> RGB, HWC -> CHW
# uint8 -> float32 in [0, 1], plus a batch dimension: [1, 3, img_size, img_size].
img = torch.from_numpy(img).to(device).float().div(255.0).unsqueeze(0)

with torch.no_grad():  # inference only: do not record autograd history
    # Plain forward pass. The original passed augment='store_true', a truthy
    # string that silently enabled test-time augmentation.
    pred = model(img)[0]
# NMS args: conf threshold, IoU threshold, classes filter (None = all), class-agnostic flag.
pred = non_max_suppression(pred, CONF_THRES, IOU_THRES, None, False, max_det=1000)
# Labels used when the model exposes no `names` attribute (e.g. some TorchScript
# exports). These are the two classes the original script mapped by hand.
_FALLBACK_NAMES = {0: 'person', 5: 'bus'}


def _class_name(model, cls_id):
    """Return a printable label for integer class index *cls_id*.

    Prefers the model's own `names` (unwrapping a `.module` wrapper if present).
    Falls back to _FALLBACK_NAMES, then to the bare index. An unmapped class is
    therefore never left as a tensor, which made `cls + ": "` raise TypeError.
    """
    names = getattr(getattr(model, 'module', model), 'names', None)
    try:
        return str(names[cls_id])
    except (TypeError, IndexError, KeyError):  # TypeError covers names is None
        return _FALLBACK_NAMES.get(cls_id, str(cls_id))


for det in pred:  # one detection tensor per image in the batch
    if not len(det):
        continue
    det = det.cpu()
    # Map boxes from the letterboxed network input back to original-image pixels.
    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()
    # Each row unpacks as four box coordinates (xyxy), a confidence and a class index.
    for *xyxy, conf, cls in reversed(det):
        x1, y1, x2, y2 = (int(v) for v in xyxy)
        score = float(conf)
        label = _class_name(model, int(cls))
        # Print the prediction as "[x1, y1, x2, y2],confidence,label".
        print(f'{[x1, y1, x2, y2]},{score:.2f},{label}')

        # Bounding box in green.
        cv2.rectangle(img0, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Label text in a random color. Plain Python ints are passed because
        # some OpenCV builds reject numpy scalars as a color.
        color = tuple(int(c) for c in np.random.randint(0, 255, size=3))
        # Clamp the baseline so labels on boxes touching the top edge stay visible.
        cv2.putText(img0, f'{label}: {score:.2f}', (x1, max(y1 - 10, 20)),
                    cv2.FONT_HERSHEY_PLAIN, 1.5, color, 2)

# Save the annotated image.
if not cv2.imwrite('out.jpg', img0):
    raise OSError('Failed to write out.jpg')
[1] 参考:Link