1. 保留原始框架结构
2. 自定义detect推理类
"""
python-inference 推理类
"""
import os
import torch
import cv2 as cv
import numpy as np
from models.experimental import attempt_load # 加载模型
from utils.general import non_max_suppression, scale_coords # nms 坐标缩放
from utils.torch_utils import select_device # 加载设别
from utils.augmentations import letterbox
class Detect(object):
    """Detector wrapper: loads weights, runs inference, and draws boxes.

    Intended call order (as the methods below suggest):
    ``read_data`` -> ``resize_zoom`` -> ``predict_``.
    """

    def __init__(self, weights, image_w, image_h, class_code, cfg, iou, max_bbox, dev):
        """
        Initialise parameters and load the model.

        inputs
        ------
        weights     model weights path (str), or a list of paths
        image_w     image width (int); stored only, not used by this class
        image_h     image height (int); stored only, not used by this class
        class_code  mapping from class index (int) to the label text drawn on the image
        cfg         confidence threshold passed to non_max_suppression (float)
        iou         IoU threshold passed to non_max_suppression (float)
        max_bbox    maximum number of detections kept per image (int)
        dev         device id (str): 'cpu', or a value understood by select_device
        """
        self.W = weights
        self.img_w = image_w
        self.img_h = image_h
        self.class_code = class_code
        self.cfg_ = cfg
        self.iou_ = iou
        self.max_bbox_ = max_bbox
        # Only the first path is inspected to decide between TorchScript and a regular checkpoint.
        self.w = str(self.W[0] if isinstance(self.W, list) else self.W)
        self.dev_ = dev if dev == 'cpu' else select_device(dev)
        if 'torchscript' in self.w:
            # map_location places the TorchScript module on the requested device instead of
            # whatever device it happened to be saved from.
            self.model = torch.jit.load(self.w, map_location=self.dev_)
        else:
            self.model = attempt_load(self.W, map_location=self.dev_)

    def resize_zoom(self, image, new_shape=(640, 640)):
        """
        Letterbox-resize an image and scale its pixel values by 1/255.

        inputs
        ------
        image      image array as returned by read_data
        new_shape  target size forwarded to letterbox (default (640, 640), the former hard-coded value)

        outputs
        -------
        the letterboxed image divided by 255.0 (a float array; float64 for uint8 input)
        """
        # letterbox returns (image, ratio, padding); only the image is needed here.
        image_ = letterbox(image, new_shape)[0]
        return image_ / 255.0

    @classmethod
    def read_data(cls, image_path):
        """
        Read an image from disk with OpenCV.

        outputs
        -------
        (image, image.shape)

        raises
        ------
        FileNotFoundError  if OpenCV cannot read the file (missing or unreadable)
        """
        image = cv.imread(image_path)
        if image is None:
            # cv.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'cannot read image: {image_path}')
        return image, image.shape

    def predict_(self, data, dev, IMG, augment=False):
        """
        Run detection and draw the results onto the original image.

        inputs
        ------
        data     letterboxed, normalised image (output of resize_zoom), HWC, BGR channel order
        dev      'cpu' keeps the input tensor on CPU; any other value moves it to self.dev_
        IMG      original image; boxes and labels are drawn onto it in place
        augment  if True, request test-time augmentation from the model (default False).
                 Only meaningful for non-TorchScript models, whose forward accepts this keyword.

        outputs
        -------
        IMG      the original image with detections drawn
        """
        # BGR -> RGB and HWC -> CHW, then add a batch dimension.
        img = np.expand_dims(data[:, :, ::-1].transpose((2, 0, 1)), axis=0)
        # The reversed-channel view has negative strides, which torch.from_numpy rejects.
        target = 'cpu' if dev == 'cpu' else self.dev_
        # Cast to float32 before any host->device copy so less data is transferred.
        img = torch.from_numpy(np.ascontiguousarray(img)).to(device=target, dtype=torch.float32)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            if augment:
                pred = self.model(img, augment=True)[0]
            else:
                pred = self.model(img)[0]

        pred = non_max_suppression(pred, self.cfg_, self.iou_, None, False, max_det=self.max_bbox_)

        for det in pred:
            if not len(det):
                continue
            # Map box coordinates from the letterboxed input back to the original image size.
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], IMG.shape).round()
            for *xyxy, conf, cls in reversed(det):
                x1, y1, x2, y2 = (int(v) for v in xyxy)
                label = self.class_code[int(cls)]
                IMG = cv.rectangle(IMG, (x1, y1), (x2, y2), (0, 0, 255), 3)
                cv.putText(IMG, label, (x1, y1 - 5), cv.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
        return IMG