# NMS
#
# Per-class non-maximum suppression for one image.
#
# image_pred_ holds every bbox predicted for the image as a 2-D tensor of
# shape (num_boxes, 4 + 1 + 1 + 1):
#   columns 0-3: top-left / bottom-right corner coordinates
#   column  4  : detection (objectness) confidence
#   column  5  : highest class probability
#   column  6  : index of the class with that highest probability
# NMS is run independently for every class found in the image.
for cls in img_classes:
    # 1. Select the detections that belong to this class.
    #    Rows of other classes are zeroed out by the mask, then the row
    #    indices whose class-probability column is still non-zero are used
    #    to pull the matching rows out of image_pred_.
    cls_mask = image_pred_ * (image_pred_[:, -1] == cls).float().unsqueeze(1)
    # squeeze(1) (not squeeze()) keeps the index 1-D even for a single hit.
    class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze(1)
    image_pred_class = image_pred_[class_mask_ind].view(-1, 7)

    # 2. Sort by detection confidence, highest first. sort()[1] returns the
    #    row permutation, which is then used to reorder the detections.
    conf_sort_index = torch.sort(image_pred_class[:, 4], descending=True)[1]
    image_pred_class = image_pred_class[conf_sort_index]

    # 3. Greedy suppression: compare box i with every lower-ranked box
    #    (rows i+1 onward) and drop the ones that overlap it too much.
    #    image_pred_class shrinks while we iterate, so the bound is
    #    re-evaluated on every pass; the loop stops as soon as box i has
    #    no lower-ranked box left to compare against.
    i = 0
    while i + 1 < image_pred_class.size(0):
        # bbox_iou indexes its arguments as [:, k], so the single row must
        # be passed as a (1, 7) tensor rather than a 1-D tensor.
        ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i + 1:])

        # Rows whose IoU reaches the threshold are multiplied by 0 (all
        # columns, including the confidence); the others are multiplied by 1.
        iou_mask = (ious < nms_conf).float().unsqueeze(1)
        image_pred_class[i + 1:] *= iou_mask

        # Keep only the rows whose confidence is still non-zero.
        non_zero_ind = torch.nonzero(image_pred_class[:, 4]).squeeze(1)
        image_pred_class = image_pred_class[non_zero_ind].view(-1, 7)
        i += 1

    # Prefix every surviving box of this class with the batch index of the image.
    batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_(ind)
    out = torch.cat((batch_ind, image_pred_class), 1)  # column-wise concat

    # Accumulate the results of all classes / images into `output`.
    if not write:
        # First result: nothing to concatenate with yet, so initialise.
        output = out
        write = True
    else:
        output = torch.cat((output, out))  # row-wise concat
# Compute IoU
def _box_corners(boxes, coordinate):
    """Return the ``(x1, y1, x2, y2)`` columns of *boxes* as corner coordinates.

    Only the first four entries along the last dimension are read, so rows
    may carry extra columns. The input tensor is never modified.

    Raises:
        ValueError: if *coordinate* is neither ``'center'`` nor ``'angle'``.
    """
    if coordinate == 'center':
        # (cx, cy, w, h) -> corners
        half_w = boxes[..., 2] / 2
        half_h = boxes[..., 3] / 2
        return (
            boxes[..., 0] - half_w,
            boxes[..., 1] - half_h,
            boxes[..., 0] + half_w,
            boxes[..., 1] + half_h,
        )
    if coordinate == 'angle':
        # Already (x1, y1, x2, y2): top-left and bottom-right corners.
        return boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
    raise ValueError(
        "coordinate must be 'center' or 'angle', got {!r}".format(coordinate)
    )


def bbox_iou(bbox1, bbox2, coordinate=None):
    """Compute the intersection-over-union between ``bbox1`` and ``bbox2``.

    Args:
        bbox1: tensor whose last dimension starts with the four box
            coordinates; either a single box (1-D or shape ``(1, C)``) or
            any shape that broadcasts against ``bbox2``'s leading dimensions.
        bbox2: tensor of boxes laid out like ``bbox1``.
        coordinate: ``'center'`` when boxes are ``(cx, cy, w, h)``,
            ``'angle'`` when they are ``(x1, y1, x2, y2)`` corners. When
            omitted, the module-level ``coordinate`` setting is used if one
            exists (the behaviour of earlier versions of this function),
            otherwise ``'angle'``.

    Returns:
        Tensor of IoU values, one per broadcast pair of boxes.

    Raises:
        ValueError: if the coordinate format is not recognised.
    """
    if coordinate is None:
        # Backward compatibility: this function used to read a global.
        coordinate = globals().get('coordinate', 'angle')

    b1_x1, b1_y1, b1_x2, b1_y2 = _box_corners(bbox1, coordinate)
    b2_x1, b2_y1, b2_x2, b2_y2 = _box_corners(bbox2, coordinate)

    # Intersection rectangle: the larger of the top-left corners and the
    # smaller of the bottom-right corners.
    inter_x1 = torch.max(b1_x1, b2_x1)
    inter_y1 = torch.max(b1_y1, b2_y1)
    inter_x2 = torch.min(b1_x2, b2_x2)
    inter_y2 = torch.min(b1_y2, b2_y2)

    # clamp(min=0) makes disjoint boxes contribute zero area instead of a
    # negative (or spuriously positive) product. The "+ 1" presumably treats
    # coordinates as inclusive pixel indices; it is applied consistently to
    # the intersection and to both box areas.
    inter_area = (
        torch.clamp(inter_x2 - inter_x1 + 1, min=0)
        * torch.clamp(inter_y2 - inter_y1 + 1, min=0)
    )

    # Union area = area1 + area2 - intersection.
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
    all_area = b1_area + b2_area - inter_area

    iou = inter_area / all_area
    return iou