batch_size = prediction.size(0) # 因为NMS是针对一个图片做的,所以要把batch拆成图片来做
write = False
for ind in range(batch_size):
# max_conf 属于最大概率的类的概率, max_conf_score 最大概率类的编号(Index)
max_conf, max_conf_score = torch.max(image_pred[:,5:5+ num_classes], 1) # 按行取最大值
max_conf = max_conf.float().unsqueeze(1) # 添加一个维度
max_conf_score = max_conf_score.float().unsqueeze(1) # 添加一个维度
seq = (image_pred[:,:5], max_conf, max_conf_score) # seq代表 anchor信息(4个维度)+ 置信信息(1个维度)+ 最大类概率 + 最大类编号
image_pred = torch.cat(seq, 1) # 要把刚刚的seq变成一个整体的tensor 所以要拼接 保持行数不动 列数相加
non_zero_ind = (torch.nonzero(image_pred[:,4])) # 列数从0开始计算,置信如果是0要被去重 留下的是置信度列不为0的行号
try:
image_pred_ = image_pred[non_zero_ind.squeeze(),:].view(-1,7) # 消除维度 去除置信度为0的行
# image_pred_ 代表 去除置信度为0的行后包含anchor(4+1信息)+分类概率和类别(一共是7个列的信息)
except:
continue
if image_pred_.shape[0] == 0:
continue
img_classes = unique(image_pred_[:,-1]) # 根据类别去重得到一共有多少类 img_classes代表一共有多少种类(不是类别数 保存的是类别名字)
for cls in img_classes:
# 筛选属于当前类的数据
cls_mask = image_pred_*(image_pred_[:,-1] == cls).float().unsqueeze(1) # 制作当前类的掩码,用于筛选属于当前类的数据 cls_mask中含有全0的行
class_mask_ind = torch.nonzero(cls_mask[:,-2]).squeeze() # class_mask_ind 代表哪些行不是全0的行 并降低一个维度
image_pred_class = image_pred_[class_mask_ind].view(-1,7) # 单个选择(像列表那样)那些不为0的数据之后,通过view以7列恢复原来的结构
# image_pred_class 存储了属于当前类的 7 列数据
conf_sort_index = torch.sort(image_pred_class[:,4], descending = True )[1] #排序 [1]是因为sort后返回两个列表[0]代表value列表是其中的值列表 [1]代表indices,返回索引
image_pred_class = image_pred_class[conf_sort_index] # 更新排序好的image_pred_class
idx = image_pred_class.size(0) #Number of detections 当前类有多少个检测框
for i in range(idx):# 第一个检测框是最大的
try:
ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i+1:]) # 当前检测框和之后所有检测框计算IoU 存储在ious中
except ValueError:
break
except IndexError:
break
iou_mask = (ious < nms_conf).float().unsqueeze(1) # 表示因为IoU重合度高原因要被舍弃的检测框的编号
image_pred_class[i+1:] *= iou_mask # 舍弃(变0)
non_zero_ind = torch.nonzero(image_pred_class[:,4]).squeeze() # 选择0(被舍弃的编号)
image_pred_class = image_pred_class[non_zero_ind].view(-1,7) # 删除0 image_pred_class代表和当前检测框去重后的情况
# 因为在做删除操作,所以range(idx)会超出image_pred_class但是没关系 超出索引范围会返回空的tensor
batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind) # 标记属于哪个batch
seq = batch_ind, image_pred_class
if not write:
output = torch.cat(seq,1)
write = True
else:
out = torch.cat(seq,1)
output = torch.cat((output,out))
YOLOv3 非极大值抑制(Non-Maximum Suppression, NMS) 代码理解
最新推荐文章于 2024-03-15 10:03:08 发布