YOLO中的非极大值抑制

``
def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
#----------------------------------------------------------#
# 将预测结果的格式转换成左上角右下角的格式。
# prediction [batch_size, num_anchors, 85],num_anchor 为3类共9个的anchor都放在一起
4+1+numclass # 1为是否有目标

    #   获取预测框四个点的坐标
    #----------------------------------------------------------#
    box_corner          = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4] #将前4个值由中心点宽高的,变换为左上角和右下角坐标
        output = [None for _ in range(len(prediction))]
        for i, image_pred in enumerate(prediction): #对每张图片进行预测[25200.85]
            #----------------------------------------------------------#
            #   对种类预测部分取max。
            #   class_conf  [num_anchors, 1]    种类置信度
            #   class_pred  [num_anchors, 1]    种类
            #   选择最大的那一个,1代表每行的最大值 image_Pred[25500,85], 判断种类取max
            #   截取到中间到从5到85之间的值
            #----------------------------------------------------------#
            t4 = image_pred[:, 5:5 + num_classes]

            #torch.max选取最大的值,因为是二维,返回值,和类别即坐标系
            class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

            #class conf [25000,1] class_pred [25200,1] ,dim=1代表维度保持一样
            #----------------------------------------------------------#
            # 返回类别
            # 利用置信度进行第一轮筛选
            # 种类置信度*预测框是否包含物体的置信度
            #----------------------------------------------------------#
            #image_pred[:,4]判断是否有物体,有物体conf也要高于0.5
            t5 = image_pred[:, 4] #所有行数保持不变,取出第5维的所有元素
            t6 = class_conf[:, 0]
            # 低于阈值0.5的取False 高于则是True 用置信度*是否有物体
            conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()

            #----------------------------------------------------------#
            #   根据置信度进行预测结果的筛选
            # 将预测的坐标,置信度,和预测的类别取出来,高于0.5的值才可以
            #----------------------------------------------------------#
            image_pred = image_pred[conf_mask]
            class_conf = class_conf[conf_mask]
            class_pred = class_pred[conf_mask]
            if not image_pred.size(0):
                continue
            #-------------------------------------------------------------------------#
            #   detections  [num_anchors, 7]
            #   7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
            #   将目标获取出
            #-------------------------------------------------------------------------#
            detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

            #------------------------------------------#
            #   获得预测结果中包含的所有种类
            #------------------------------------------#
            unique_labels = detections[:, -1].cpu().unique()

            if prediction.is_cuda:
                unique_labels = unique_labels.cuda()
                detections = detections.cuda()

            for c in unique_labels:
                #------------------------------------------#
                #   获得某一类得分筛选后全部的预测结果
                #------------------------------------------#
                detections_class = detections[detections[:, -1] == c]
                #  获得相同类别得anchor框
                #------------------------------------------#
                #   使用官方自带的非极大抑制会速度更快一些!
                #   筛选出一定区域内,属于同一种类得分最大的框
                #------------------------------------------#

                keep = nms(
                    detections_class[:, :4], #4坐标位置
                    detections_class[:, 4] * detections_class[:, 5], #有物体 * 置信度
                    nms_thres
                ) #判断重合程度是否大于nms
                max_detections = detections_class[keep]
                
                # # 按照存在物体的置信度排序
                # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
                # detections_class = detections_class[conf_sort_index]
                # # 进行非极大抑制
                # max_detections = []
                # while detections_class.size(0):
                #     # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
                #     max_detections.append(detections_class[0].unsqueeze(0))
                #     if len(detections_class) == 1:
                #         break
                #     ious = bbox_iou(max_detections[-1], detections_class[1:])
                #     detections_class = detections_class[1:][ious < nms_thres]
                # # 堆叠
                # max_detections = torch.cat(max_detections).data
                
                # Add max detections to outputs
                output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
            
            if output[i] is not None:
                output[i]           = output[i].cpu().numpy()
                box1 = output[i][:, 0:2]
                box2 = output[i][:, 2:4]
                box_xy, box_wh      = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
                output[i][:, :4]    = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        return output


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值