python实现多类别的nms过程

import numpy as np


def nms_multi_cls(boxes, bbox_threshold,overlap_threshold):
    '''

    :param boxes: [bbox_score,xmin,ymin,xmax,ymax,c1_score,c2_score,...,cn_score]
    :return:
    '''
    # 如果没有检测到任何框,返回一个空列表
    if len(boxes) == 0:
        return []
        boxes

    ## 筛选出置信度较高的框
    boxes = boxes[boxes[:,0] > bbox_threshold]
    # 初始化保留框列表和分数列表
    result = []
    ## 截选出类别概率
    cls_scores = boxes[:, 5:]  ##
    max_cls_index = np.argmax(cls_scores, axis=-1)  ## 最大的类别分数的索引
    max_cls_score = np.max(cls_scores, axis=-1)  ## 最大类别分数
    ## detection [bbox_score,xmin,ymin,xmax,ymax,max_cls_score,max_cls]
    detections = np.concatenate([boxes[:, :5], max_cls_score[:, np.newaxis], max_cls_index[:, np.newaxis]], axis=-1)
    detections[:, 0] = detections[:, 0] * detections[:, 5]
    detections = detections[detections[:, 0] > bbox_threshold]
    uniq_cls = np.unique(max_cls_index) ## 检测结果存在的类别

    for c in uniq_cls: ## 遍历每个类别,进行nms操作
        det = detections[detections[:, -1] == c]
        dets = nms_pi(det, thresh=overlap_threshold)
        # Add max detections to outputs
        if len(dets):
            result.append(dets)
    if len(result):
        result = np.concatenate(result, axis=0)
    else:
        return []
    return result


def nms_pi(dets, thresh=0.25):
    """
    refer to:
    https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/cython_nms.pyx
    Apply classic DPM-style greedy NMS.
    """
    if dets.shape[0] == 0:
        return dets[[], :]
    scores = dets[:, 0]
    x1 = dets[:, 1]
    y1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1) ## 计算左右框的面积
    order = scores.argsort()[::-1] ## 按类别置信度由大到小排序

    ndets = dets.shape[0]
    suppressed = np.zeros((ndets), dtype=np.int) ## 用来记录检测框是否被抑制了

    for _i in range(ndets): 
        i = order[_i] ## order,按检测框置信度由大到小取出检测它的索引
        if suppressed[i] == 1: 
            continue
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        iarea = areas[i] ## 取出面积
        for _j in range(_i + 1, ndets): ## 依次遍历剩下的框, 计算与当前最大的置信度的框之间的iou
            j = order[_j]
            if suppressed[j] == 1:
                continue
            xx1 = max(ix1, x1[j])
            yy1 = max(iy1, y1[j])
            xx2 = min(ix2, x2[j])
            yy2 = min(iy2, y2[j])
            w = max(0.0, xx2 - xx1 + 1)
            h = max(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (iarea + areas[j] - inter)

            if ovr >= thresh:
                suppressed[j] = 1
    keep = np.where(suppressed == 0)[0]
    dets = dets[keep, :]
    ### score,xmin,ymin,xmax,ymax,?
    return dets


if __name__ == '__main__':
    boxes = np.zeros((17, 10))
    ## [bbox_score,xmin,ymin,xmax,ymax,c1_score,c2_score,...,cn_score]
    bbox_threshold = 0.25
    overlap_threshold = 0.25
    nms_multi_cls(boxes, bbox_threshold,overlap_threshold)

nms是抑制掉同类别间重合度较高的框

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值