yolov5 用onnx推理支持按类别进行NMS

1.背景

目标检测中多类别经常需要同类别间做NMS,不同类别间不做NMS,官方的yolov5用pt推理里面已经实现了一版:

实现代码在utils/general.py

        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS

agnostic参数:True表示所有类别一起计算NMS,False表示按照不同的类别分别计算NMS。
代码重点是在“+ c”这里,c就是偏移量。
x[:, :4]表示box(从二维看第0,1,2,3列)
x[:, 4] 表示分数(从二维看第4列)
x[:, 5:6]表示类IDX(从二维看第5列)
max_wh是框宽高的上限(官方早期版本取4096,下文完整代码中取7680)。只要它大于图像的最大边长,不同类别的框加上各自的偏移量后就不会相互重叠,所以偏移量仅取决于类IDX,并且足够大。

2.基于onnx推理

基于onnx推理的也是通过agnostic=False实现,原理和上面基于pt推理的一样,也是需要偏移量,核心地方如下:


    # Batched NMS
    c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
    boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
    i = nms(boxes, scores, iou_thres)  # NMS

完整代码如下:

"""
Pre-processing and post-processing helpers for detection.
"""
import time
import cv2
import numpy as np
import logging
import onnxruntime
import copy
import os
# Box colours passed to cv2.rectangle in yolov5Det.draw_vis, indexed by class id.
# Only five entries are defined.
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]

def cv2_imread(path):
    """Read the image file at ``path`` and decode it as a colour image.

    The file is loaded as raw bytes with ``np.fromfile`` and decoded in
    memory by ``cv2.imdecode`` (flag ``1``) rather than by ``cv2.imread``.

    :param path: path of the image file to load
    :return: whatever ``cv2.imdecode`` returns for the file's bytes
    """
    raw_bytes = np.fromfile(path, dtype=np.uint8)
    return cv2.imdecode(raw_bytes, 1)


def cv2_imwrite(image, image_path, type='jpg'):
    """Encode ``image`` in memory and write the encoded bytes to ``image_path``.

    :param image: image array handed to ``cv2.imencode``
    :param image_path: destination file path
    :param type: file-format extension without the leading dot (default ``jpg``)
    """
    extension = '.{}'.format(type)
    encoded = cv2.imencode(extension, image)[1]
    encoded.tofile(image_path)

def my_letter_box(img, size=(640, 640)):
    """Letterbox ``img`` into a canvas of ``size`` (height, width).

    The image is resized with its aspect ratio preserved so that it fits
    inside ``size``; the remaining border is filled with grey (128, 128, 128).

    :param img: H x W x C image array
    :param size: target (height, width)
    :return: ``(padded_image, scale, (pad_left, pad_top))``
    """
    target_h, target_w = size[0], size[1]
    src_h, src_w, _ = img.shape

    # One scale for both axes so the aspect ratio is preserved.
    scale = min(target_h / src_h, target_w / src_w)
    resized_h = int(src_h * scale)
    resized_w = int(src_w * scale)

    # Split the leftover space; bottom/right absorb any odd pixel.
    pad_top = int((target_h - resized_h) / 2)
    pad_left = int((target_w - resized_w) / 2)
    pad_bottom = target_h - resized_h - pad_top
    pad_right = target_w - resized_w - pad_left

    resized = cv2.resize(img, (resized_w, resized_h))
    padded = cv2.copyMakeBorder(
        resized, pad_top, pad_bottom, pad_left, pad_right,
        borderType=cv2.BORDER_CONSTANT, value=(128, 128, 128))
    return padded, scale, (pad_left, pad_top)

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    """Rescale xyxy boxes from the letterboxed shape back to the original image.

    ``coords`` is modified in place and also returned.

    :param img1_shape: (height, width) of the letterboxed network input
    :param coords: (n, 4+) float array whose first four columns are x1, y1, x2, y2
    :param img0_shape: shape of the original image; only [0]=height, [1]=width are used
    :param ratio_pad: optional ``((gain, ...), (pad_x, pad_y))``; when omitted,
        gain and padding are derived from the two shapes
    :return: ``coords``, rescaled and clipped to the original image bounds
    """
    if ratio_pad is None:  # derive gain and padding from the two shapes
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = ((img1_shape[1] - img0_shape[1] * gain) / 2,
               (img1_shape[0] - img0_shape[0] * gain) / 2)  # (x, y) padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    coords[:, [0, 2]] -= pad[0]  # remove x padding
    coords[:, [1, 3]] -= pad[1]  # remove y padding
    coords[:, :4] /= gain

    # Clip to the original image. This used to call clip_coords(), which is
    # not defined anywhere in this file, so the clipping is done inline.
    coords[:, [0, 2]] = coords[:, [0, 2]].clip(0, img0_shape[1])  # x1, x2
    coords[:, [1, 3]] = coords[:, [1, 3]].clip(0, img0_shape[0])  # y1, y2
    return coords


def nms(bboxes, scores, iou_thresh):
    """Greedy non-maximum suppression on xyxy boxes.

    Repeatedly keeps the highest-scoring remaining box and drops every other
    box whose IoU with it is greater than ``iou_thresh``.

    :param bboxes: (n, 4) array of boxes as x1, y1, x2, y2
    :param scores: (n,) array of confidences
    :param iou_thresh: boxes with IoU above this value are suppressed
    :return: list of kept indices into ``bboxes``, highest score first
    """
    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)

    keep = []
    order = scores.argsort()[::-1]  # indices sorted by descending confidence
    while order.size > 0:
        best = order[0]
        keep.append(best)
        rest = order[1:]

        # Intersection of the best box with each remaining box. Areas above
        # are computed without the "+1" pixel convention, so the intersection
        # must not use it either; mixing the two inflates the IoU.
        inter_w = np.maximum(0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]))
        inter_h = np.maximum(0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]))
        inter = inter_w * inter_h

        union = areas[best] + areas[rest] - inter
        ious = inter / np.maximum(union, 1e-12)  # guard against 0/0 for degenerate boxes

        # Keep only the boxes that do not overlap the best one too much.
        order = rest[ious <= iou_thresh]
    return keep


def xyxy2xywh(x):
    """Convert (n, 4) boxes from corner form to centre form.

    Input columns are x1, y1, x2, y2 (top-left and bottom-right corners);
    output columns are centre x, centre y, width, height. ``x`` is not modified.
    """
    left, top, right, bottom = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
    converted = np.copy(x)
    converted[:, 2] = right - left  # width
    converted[:, 3] = bottom - top  # height
    converted[:, 0] = (left + right) / 2  # centre x
    converted[:, 1] = (top + bottom) / 2  # centre y
    return converted


def xywh2xyxy(x):
    """Convert (n, 4) boxes from centre form to corner form.

    Input columns are centre x, centre y, width, height; output columns are
    x1, y1, x2, y2 (top-left and bottom-right corners). ``x`` is not modified.
    """
    centre_x, centre_y = x[:, 0], x[:, 1]
    half_w = x[:, 2] / 2
    half_h = x[:, 3] / 2
    corners = np.copy(x)
    corners[:, 0] = centre_x - half_w  # top-left x
    corners[:, 1] = centre_y - half_h  # top-left y
    corners[:, 2] = centre_x + half_w  # bottom-right x
    corners[:, 3] = centre_y + half_h  # bottom-right y
    return corners


def non_max_suppression(prediction,
                        conf_thres=0.25,
                        iou_thres=0.45,
                        classes=None,
                        agnostic=False,
                        multi_label=False,
                        labels=(),
                        max_det=300):
    """Non-Maximum Suppression (NMS) on inference results, implemented with NumPy.

    Args:
        prediction: array of shape (batch, boxes, 5 + nc) whose columns are
            centre x, centre y, width, height, objectness, then class scores.
        conf_thres: minimum objectness and minimum final confidence.
        iou_thres: IoU threshold handed to ``nms``.
        classes: optional class ids to keep; ``None`` keeps every class.
        agnostic: if True all classes compete in one NMS pass; if False each
            box is shifted by ``class_id * max_wh`` so only boxes of the same
            class can suppress each other.
        multi_label: accepted for signature compatibility; only the best
            class per box is used in this implementation.
        labels: optional per-image a-priori labels with columns
            (class, x, y, w, h) that are appended as confidence-1 detections.
        max_det: maximum number of detections kept per image.

    Returns:
         list of detections, one (n, 6) array per image [xyxy, conf, cls]
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    max_wh = 7680  # (pixels) maximum box width and height, used as the per-class offset
    max_nms = 30000  # maximum number of boxes passed to nms()
    time_limit = 0.3 + 0.03 * bs  # seconds to quit after

    t = time.time()
    output = [np.zeros((0, 6))] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # confidence filter (boolean indexing returns a copy)

        # Append a-priori labels as confidence-1 detections
        if labels and len(labels[xi]):
            lb = np.asarray(labels[xi])
            v = np.zeros((len(lb), nc + 5), dtype=x.dtype)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[np.arange(len(lb)), lb[:, 0].astype(int) + 5] = 1.0  # cls
            x = np.concatenate((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls): best class only
        conf = x[:, 5:].max(1, keepdims=True)
        j = x[:, 5:].argmax(1)[:, None]
        x = np.concatenate((box, conf, j), 1)[conf.reshape(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[np.isin(x[:, 5], classes)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes: keep the most confident ones
            x = x[x[:, 4].argsort()[::-1][:max_nms]]

        # Batched NMS: offsetting boxes by class keeps different classes apart
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = nms(boxes, scores, iou_thres)  # NMS
        if len(i) > max_det:  # limit detections
            i = i[:max_det]

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            logging.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded
    return output


class yolov5Det():
    """YOLOv5 detector running an ONNX model through onnxruntime."""

    def __init__(self, weights, img_size=(640, 640), conf_thres=0.45,
                 iou_thres=0.50, max_det=1000, agnostic_nms=False, device='cpu'):
        """Create the inference session.

        :param weights: path of the ONNX model file
        :param img_size: (height, width) the input image is letterboxed to
        :param conf_thres: confidence threshold passed to non_max_suppression
        :param iou_thres: IoU threshold passed to non_max_suppression
        :param max_det: maximum detections kept per image
        :param agnostic_nms: True for one NMS over all classes, False for per-class NMS
        :param device: 'cpu' uses the CPU provider only; any other value also
            requests the CUDA provider
        """
        self.weights = weights
        self.img_size = img_size
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.max_det = max_det
        self.agnostic_nms = agnostic_nms
        self.device = device

        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.device != 'cpu' else [
            'CPUExecutionProvider']
        self.session = onnxruntime.InferenceSession(weights, providers=providers)
        self.names = ["a", "b", "c", "d"]  # replace with the class names of your own model

    def data_preprocess(self, img0s):
        """Letterbox, reorder and normalise one image into a network input.

        :param img0s: H x W x C image as loaded by OpenCV
        :return: float32 array of shape (1, C, H, W) scaled to 0.0 - 1.0
        """
        img, _, _ = my_letter_box(img0s, size=self.img_size)
        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        img = img.astype(dtype=np.float32)
        img /= 255  # 0 - 255 to 0.0 - 1.0
        if len(img.shape) == 3:
            img = img[None]  # expand for batch dim
        return img

    def pred(self, img0s):
        """Detect objects in ``img0s`` and draw them onto it.

        Each detection is drawn in place on ``img0s`` and printed; add your
        own post-processing on top of the detections as needed.

        :param img0s: H x W x C image as loaded by OpenCV
        :return: ``img0s`` with the detections drawn on it
        """
        img = self.data_preprocess(img0s)
        # Inference
        pred = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: img})[0]

        # NMS
        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, None, self.agnostic_nms, max_det=self.max_det)
        det = pred[0]  # detections single image
        # Process detections
        if len(det):
            # Rescale boxes from img_size to im0 size
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0s.shape).round()
            # Write results
            for *xyxy, conf, cls in reversed(det):
                cls = int(cls)
                # Fall back to the numeric id when the model has more classes
                # than names were configured.
                label = f'{self.names[cls]}' if cls < len(self.names) else str(cls)
                prob = round(float(conf), 2)  # round 2
                # Img vis
                xmin, ymin, xmax, ymax = xyxy
                newpoints = [(int(xmin), int(ymin)), (int(xmax), int(ymax))]
                self.draw_vis(img0s, newpoints, label, prob, cls)
                print('-----', img0s.shape, xyxy, label)
        return img0s

    def draw_vis(self, img, pts, label, prob, cls):
        """Draw one labelled box on ``img`` in place.

        :param img: image to draw on
        :param pts: [(xmin, ymin), (xmax, ymax)] corner points
        :param label: class name shown next to the box
        :param prob: confidence shown after the label
        :param cls: class id used to pick the box colour
        :return: ``img``
        """
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Plain int tuples are passed as the points.
        top_left = tuple(int(v) for v in pts[0])
        bottom_right = tuple(int(v) for v in pts[1])
        # Wrap around the palette so class ids beyond its length do not raise.
        box_color = colors[cls % len(colors)]
        cv2.rectangle(img, top_left, bottom_right, box_color, 2)
        cv2.putText(img, label + '_' + str(prob), top_left, font, 1, (0, 0, 255), 1, cv2.LINE_AA)

        return img

def allFilePath(rootPath, allFIleList):
    """Recursively collect the paths of all files under ``rootPath``.

    Every entry that ``os.path.isfile`` accepts is appended to ``allFIleList``;
    every other entry is descended into as a directory. Nothing is returned,
    the list is filled in place.
    """
    for entry in os.listdir(rootPath):
        full_path = os.path.join(rootPath, entry)
        if not os.path.isfile(full_path):
            allFilePath(full_path, allFIleList)
            continue
        allFIleList.append(full_path)

if __name__ == "__main__":
    # Demo: run the detector on every .jpg/.png below the image directory and
    # show each annotated result until a key is pressed.
    onnx_weight_path = r"/data/yolov5s.onnx"
    img_dir = r"/data/imgs_ls/"
    file_list = []
    allFilePath(img_dir, file_list)
    cert_material_det = yolov5Det(onnx_weight_path)
    for img_pth in file_list:
        if not img_pth.endswith((".jpg", ".png")):
            print(img_pth)  # report files skipped because of their extension
            continue
        img = cv2_imread(img_pth)
        if img is None:  # cv2.imdecode could not decode the file
            logging.warning("could not decode image: %s", img_pth)
            continue
        out = cert_material_det.pred(img)
        cv2.imshow("re", out)
        cv2.waitKey(0)

  • 7
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值