# YOLOv5 detection-side prediction wrapper
import torch
from yolov5.models.common import DetectMultiBackend
from yolov5.utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
import cv2
from PIL import Image
import numpy as np
from torchvision import transforms
from matplotlib import pyplot as plt
from PIL import Image
# NOTE(review): fragment of the enclosing class's __init__ (class header is not
# visible in this chunk). Loads the YOLOv5 model from self.weights via
# DetectMultiBackend with OpenCV-DNN and FP16 disabled and no dataset yaml.
self.model = DetectMultiBackend(self.weights, device=device, dnn=False, data="", fp16=False)
def predict(self, source, **kwargs):
    """Run YOLOv5 detection on a single image file.

    Args:
        source: Path to the input image.
        **kwargs: Optional overrides:
            imgsz (tuple): (width, height) the image is resized to for
                inference. Defaults to (224, 480) as in the original code.
            conf_thres (float): NMS confidence threshold, default 0.25.
            iou_thres (float): NMS IoU threshold, default 0.45.
            max_det (int): Maximum detections per image, default 1000.
            show (bool): Display the annotated image, default True.
            save (bool): Save the annotated image as "<index>.png", default True.

    Returns:
        A list with one entry per image that had detections; each entry is a
        list of (class_id, (x1, y1), (x2, y2), confidence) tuples in the
        original image's pixel coordinates. Returns None when
        self.algoCategory is not "v5" (matches the original implicit None).
    """
    if self.algoCategory != "v5":
        return None  # only the YOLOv5 path is implemented here

    # Load the image once; keep a full-resolution BGR copy for drawing,
    # and a resized copy for inference.
    pil_img = Image.open(source).convert('RGB')
    width, height = kwargs.get("imgsz", (224, 480))
    resized = pil_img.resize((width, height))
    img0 = cv2.cvtColor(np.asarray(pil_img), cv2.COLOR_RGB2BGR)

    # HWC uint8 -> CHW float32 in [0, 1], then add a batch dimension.
    # BUG FIX: the original stacked the same tensor twice (debug leftover),
    # duplicating every result in the returned list.
    im = transforms.Compose([transforms.ToTensor()])(resized)
    im = im.to(self.model.device).unsqueeze(0)

    # Inference followed by non-maximum suppression.
    pred = self.model(im, augment=False, visualize=False)
    pred = non_max_suppression(
        pred,
        conf_thres=kwargs.get("conf_thres", 0.25),
        iou_thres=kwargs.get("iou_thres", 0.45),
        max_det=kwargs.get("max_det", 1000),
    )

    show = kwargs.get("show", True)
    save = kwargs.get("save", True)

    res_list = []  # results across all images in the batch
    for i, det in enumerate(pred):  # per image
        if not len(det):
            continue  # original also skipped empty detections entirely
        res = []  # results for this image
        # Rescale boxes from the inference size back to the original size.
        det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img0.shape).round()
        for *xyxy, conf, cls in reversed(det):
            p1 = (int(xyxy[0]), int(xyxy[1]))
            p2 = (int(xyxy[2]), int(xyxy[3]))
            cv2.rectangle(img0, p1, p2, (0, 255, 255), thickness=1, lineType=cv2.LINE_AA)
            res.append((cls.item(), p1, p2, conf.item()))
        # Show/save once per image (the original re-showed and re-saved the
        # same file for every box, and `str(i).png` raised AttributeError).
        if show or save:
            annotated = Image.fromarray(cv2.cvtColor(img0, cv2.COLOR_BGR2RGB))
            if show:
                annotated.show()
            if save:
                annotated.save(f"{i}.png")
        res_list.append(res)
    return res_list