YOLOV3 NMS笔记
1 参数
假设当前网络输入大小为416*416 ,分类为3个 [a,b,c] ,每个YOLO层3个anchor
输入图像:img=1*3*416*416 (代表这个批次内一张图片,通道为3)
经过降采样最大的YOLO层后(yolov3 有多个yolo层) ,输出
prediction=1*24*13*13,其中13*13 是最后的特征图的大小,24代表(5+3)*3
2 对yolo层的输出prediction做置信度处理和NMS(非极大值抑制)
(1)对prediction的 shape变换,这样为了更方便的进行非极大值抑制
1*24*13*13 变换为1*507*8,其中8维的数据代表 [centerx,centery,w,h,confidence,c1,c2,c3],c1,c2,c3是三个类别的得分
其中507是13*13*3(3个anchors)
(2)将confidence低于阈值(此处为0.5)的候选框移除
# Confidence thresholding: boxes whose confidence (index 4 of the last axis)
# does not exceed `confidence` are not removed here; all of their values are
# multiplied by 0 instead. They could equally be dropped at this point.
above_threshold = prediction[:, :, 4] > confidence
conf_mask = above_threshold.float().unsqueeze(2)
prediction = prediction * conf_mask

# Convert (center x, center y, w, h) into corner form (x1, y1, x2, y2).
half_w = prediction[:, :, 2] / 2
half_h = prediction[:, :, 3] / 2
# Scratch tensor with the same shape as prediction; only its first four
# entries along the last axis are filled in and copied back.
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - half_w
box_corner[:, :, 1] = prediction[:, :, 1] - half_h
box_corner[:, :, 2] = prediction[:, :, 0] + half_w
box_corner[:, :, 3] = prediction[:, :, 1] + half_h
prediction[:, :, :4] = box_corner[:, :, :4]
(3)非极大值NMS抑制
# Per-image, per-class non-maximum suppression (NMS).
# Names read from outside this snippet: prediction, batch_size, num_classes,
# nms_conf, unique and bbox_iou.
# NOTE: `output` is only bound once at least one box survives; when
# `firstflag` is still False afterwards there were no detections at all.
firstflag = False  # becomes True once `output` holds its first rows
for ind in range(batch_size):  # handle the images of the batch one by one
    image_pred = prediction[ind]  # candidate boxes of this image

    # For every box: the highest class score and the index of that class.
    max_conf_score, max_conf_class = torch.max(
        image_pred[:, 5:5 + num_classes], 1)
    max_conf_class = max_conf_class.float().unsqueeze(1)
    max_conf_score = max_conf_score.float().unsqueeze(1)

    # Rebuild each row as the 7 values that are eventually returned:
    # x1, y1, x2, y2, confidence, class score, class index.
    seq = (image_pred[:, :5], max_conf_score, max_conf_class)
    image_pred = torch.cat(seq, 1)

    # Keep only the boxes whose confidence is non-zero. An explicit emptiness
    # check replaces the former bare `except:` so real errors are not hidden.
    non_zero_ind = torch.nonzero(image_pred[:, 4])
    if non_zero_ind.numel() == 0:
        continue  # nothing valid was detected in this image
    # squeeze(1) keeps the index 1-D even when only one row matches.
    image_pred_ = image_pred[non_zero_ind.squeeze(1), :].view(-1, 7)

    # Classes present among the remaining boxes (last column is the class).
    img_classes = unique(image_pred_[:, -1])

    # NMS runs independently per class: boxes of different classes never
    # suppress each other, however large their IoU is.
    for cls in img_classes:
        # Select the rows of this class. NOTE: rows are picked through the
        # non-zero entries of the class-score column, so a row whose class
        # score is exactly 0 would not be selected.
        cls_mask = image_pred_ * (
            (image_pred_[:, -1] == cls).float().unsqueeze(1))
        class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze(1)
        image_pred_class = image_pred_[class_mask_ind].view(-1, 7)

        # Sort by confidence (column 4), highest first.
        conf_sort_index = torch.sort(
            image_pred_class[:, 4], descending=True)[1]
        image_pred_class = image_pred_class[conf_sort_index]
        idx = image_pred_class.size(0)  # number of candidates before NMS

        for i in range(idx):
            # Suppression shrinks image_pred_class, so stop as soon as no
            # box is left after position i. This explicit bound replaces the
            # former `except Exception`, which also swallowed unrelated
            # errors raised inside bbox_iou.
            if i + 1 >= image_pred_class.size(0):
                break

            # IoU of box i against every box that follows it.
            ious = bbox_iou(image_pred_class[i].unsqueeze(0),
                            image_pred_class[i + 1:])

            # Rows with IoU below nms_conf are kept; the others are zeroed.
            iou_mask = (ious < nms_conf).float().unsqueeze(1)
            image_pred_class[i + 1:] *= iou_mask

            # Keep only the rows that were not zeroed.
            non_zero_ind = torch.nonzero(image_pred_class[:, 4]).squeeze(1)
            image_pred_class = image_pred_class[non_zero_ind].view(-1, 7)

        # Prepend a column holding the index of the image within the batch.
        batch_ind = image_pred_class.new(
            image_pred_class.size(0), 1).fill_(ind)
        seq = batch_ind, image_pred_class

        # Accumulate the surviving rows; adapt how the result is returned to
        # the surrounding code as needed.
        out = torch.cat(seq, 1)
        if firstflag:
            output = torch.cat((output, out))
        else:
            output = out
            firstflag = True
输出的 output 是剩下的候选框n*8
格式: 图片索引号,x1,y1,x2,y2,confidence,classscore,class
def unique(tensor):
    """Return the sorted distinct values of ``tensor`` as a new 1-D tensor.

    The values are de-duplicated with ``np.unique`` on a CPU copy and then
    copied into a tensor created from ``tensor`` via ``tensor.new``.

    Args:
        tensor: the tensor whose distinct values are wanted.

    Returns:
        A 1-D tensor holding each distinct value once, in ascending order.
    """
    # detach() first: numpy() raises on a tensor that requires grad.
    tensor_np = tensor.detach().cpu().numpy()
    unique_np = np.unique(tensor_np)
    unique_tensor = torch.from_numpy(unique_np)

    # Allocate through `tensor.new` and copy the distinct values into it.
    tensor_res = tensor.new(unique_tensor.shape)
    tensor_res.copy_(unique_tensor)
    return tensor_res
计算IOU的代码,要求box1 ,box2都是(x1,y1,x2,y2)的格式
def bbox_iou(box1, box2):
    """Return the IoU between the boxes of ``box1`` and those of ``box2``.

    Both arguments are 2-D tensors whose first four columns are
    (x1, y1, x2, y2); any further columns are ignored. The columns are
    combined element-wise with broadcasting, so a single-row ``box1`` is
    compared against every row of ``box2``.

    Widths and heights are computed as ``x2 - x1 + 1`` / ``y2 - y1 + 1``
    (presumably inclusive pixel coordinates -- confirm against the caller).

    Returns:
        A 1-D tensor of IoU values.
    """
    # Corner coordinates of both sets of boxes.
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    # Corners of the intersection rectangle.
    inter_rect_x1 = torch.max(b1_x1, b2_x1)
    inter_rect_y1 = torch.max(b1_y1, b2_y1)
    inter_rect_x2 = torch.min(b1_x2, b2_x2)
    inter_rect_y2 = torch.min(b1_y2, b2_y2)

    # Intersection area; clamping at 0 handles boxes that do not overlap.
    inter_w = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0)
    inter_h = torch.clamp(inter_rect_y2 - inter_rect_y1 + 1, min=0)
    inter_area = inter_w * inter_h

    # Union area = sum of both areas minus the shared part.
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)

    iou = inter_area / (b1_area + b2_area - inter_area)
    return iou