YOLOv3 非极大值抑制(Non-Maximum Suppression, NMS) 代码理解

    batch_size = prediction.size(0) # 因为NMS是针对一个图片做的,所以要把batch拆成图片来做
    write = False
    for ind in range(batch_size):
    	# max_conf 属于最大概率的类的概率, max_conf_score 最大概率类的编号(Index)
        max_conf, max_conf_score = torch.max(image_pred[:,5:5+ num_classes], 1) # 按行取最大值
        max_conf = max_conf.float().unsqueeze(1) # 添加一个维度
        max_conf_score = max_conf_score.float().unsqueeze(1) # 添加一个维度
        seq = (image_pred[:,:5], max_conf, max_conf_score) # seq代表 anchor信息(4个维度)+ 置信信息(1个维度)+ 最大类概率 + 最大类编号
        image_pred = torch.cat(seq, 1) # 要把刚刚的seq变成一个整体的tensor 所以要拼接 保持行数不动 列数相加
        non_zero_ind =  (torch.nonzero(image_pred[:,4])) # 列数从0开始计算,置信如果是0要被去重 留下的是置信度列不为0的行号
        try:
            image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7) # 消除维度 去除置信度为0的行
            # image_pred_ 代表 去除置信度为0的行后包含anchor(4+1信息)+分类概率和类别(一共是7个列的信息)
        except:
            continue
        if image_pred_.shape[0] == 0:
            continue
        img_classes = unique(image_pred_[:,-1]) # 根据类别去重得到一共有多少类 img_classes代表一共有多少种类(不是类别数 保存的是类别名字)
        for cls in img_classes:
        	# 筛选属于当前类的数据
            cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1) # 制作当前类的掩码,用于筛选属于当前类的数据 cls_mask中含有全0的行
            class_mask_ind   = torch.nonzero(cls_mask[:,-2]).squeeze() # class_mask_ind 代表哪些行不是全0的行 并降低一个维度
            image_pred_class = image_pred_[class_mask_ind].view(-1,7) # 单个选择(像列表那样)那些不为0的数据之后,通过view以7列恢复原来的结构 
            # image_pred_class 存储了属于当前类的 7 列数据
            conf_sort_index = torch.sort(image_pred_class[:,4], descending = True )[1] #排序 [1]是因为sort后返回两个列表[0]代表value列表是其中的值列表 [1]代表indices,返回索引
            image_pred_class = image_pred_class[conf_sort_index] # 更新排序好的image_pred_class
            idx = image_pred_class.size(0)   #Number of detections 当前类有多少个检测框
            
            for i in range(idx):# 第一个检测框是最大的
                try:
                    ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:]) # 当前检测框和之后所有检测框计算IoU 存储在ious中
                except ValueError:
                    break
            
                except IndexError:
                    break
                iou_mask = (ious < nms_conf).float().unsqueeze(1) # 表示因为IoU重合度高原因要被舍弃的检测框的编号
                image_pred_class[i+1:] *= iou_mask # 舍弃(变0)
                non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze() # 选择0(被舍弃的编号)
                image_pred_class = image_pred_class[non_zero_ind].view(-1,7) # 删除0 image_pred_class代表和当前检测框去重后的情况
                # 因为在做删除操作,所以range(idx)会超出image_pred_class但是没关系 超出索引范围会返回空的tensor
            batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind) # 标记属于哪个batch
            seq = batch_ind, image_pred_class
            
            if not write:
                output = torch.cat(seq,1)
                write = True
            else:
                out = torch.cat(seq,1)
                output = torch.cat((output,out))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值