测试过程
在少样本目标检测任务中,训练过程每次输入的是一个episode
,也就是有一张query image
需要检测对象,以及20张support images
提供类别信息。在support images
中,一共有2个类别,一个类别10张图片。在训练时会确保一个类和query image
相同,一个类别和它不同。训练时模型的任务就是找到query image
中属于support images
类别的对象,然而在测试的时候每张图片需要和20个类别的support features
做对比。
总体流程如下:
- 每个类别
suppport feature
与query feature
结合得到100个proposals
,这100个proposals
只和该类support feature
对比得到类别概率和坐标偏移,如果最终这个proposals
能够作为预测结果保留,则它的类别就是这个support
类。 - 这20个类别对应的
proposals
与对应的support feature
对比得到其概率得分和坐标偏移,将正类概率作为置信度,保留符合要求的proposals
。 - 在剩下的
proposals
中再进行一次非极大值抑制,将剩下的不超过100个proposals
作为最终的预测结果。
特征提取
为了测试结果的稳定性,预测挑选好200张support images,一共20个类别,每个类别10张图片,然后使用训练好的backbone对图片提取特征。
生成proposals
此时已经获得20个类别的support feature,将每个support feature作为attention与query image结合后输入到PRN中,每个RPN会生成100个proposals,所以一共生成2000个proposals。
support_proposals_dict = {} # 关键字是类别id,保存对应的proposals
support_box_features_dict = {} # 关键字是类别id,保存对应的supprot box特征
proposal_num_dict = {} # 关键字是类别id,保存对应的proposals数量
# 遍历20个类别的support feature
for cls_id, res4_avg in self.support_dict['res4_avg'].items():
query_images = ImageList.from_tensors([images[0]]) # one query image
query_features_res4 = features['res4'] # one query feature for attention rpn
query_features = {'res4': query_features_res4} # one query feature for rcnn
# support branch ##################################
support_box_features = self.support_dict['res5_avg'][cls_id]
correlation = F.conv2d(query_features_res4, res4_avg.permute(1,0,2,3), groups=1024) # attention map
support_correlation = {'res4': correlation} # attention map for attention rpn
proposals, _ = self.proposal_generator(query_images, support_correlation, None)
support_proposals_dict[cls_id] = proposals
support_box_features_dict[cls_id] = support_box_features
if cls_id not in proposal_num_dict.keys():
proposal_num_dict[cls_id] = []
proposal_num_dict[cls_id].append(len(proposals[0]))
results, _ = self.roi_heads.eval_with_support(query_images, query_features, support_proposals_dict, support_box_features_dict)
计算Proposals类别及坐标偏移
上面已经提到,每个类别的support feature 和query image结合输入RPN生成100个proposals。在Faster-Rcnn检测框架的第二阶段,这100个proposas只与该类support feature进行对比,得到分类得分和坐标偏移。
for cls_id in cls_ls:
support_box_features = support_box_features_dict[cls_id] # 该类别对应的support feature
support_box_features_res3 = support_box_features_dict_res3[cls_id]
query_features = box_features[cnt*100:(cnt+1)*100] # 该类别对应的100个proposals feature
query_features_res3 = box_features_res3[cnt*100:(cnt+1)*100]
# 计算这100个propsals与该类表征的关系得分和坐标
pred_class_logits, pred_proposal_deltas = self.box_predictor(query_features, support_box_features,
query_features_res3, support_box_features_res3)
full_scores_ls.append(pred_class_logits)
full_bboxes_ls.append(pred_proposal_deltas)
full_cls_ls.append(torch.full_like(pred_class_logits[:, 0].unsqueeze(-1), cls_id).to(torch.int8))
del query_features
del support_box_features
cnt += 1
class_logits = torch.cat(full_scores_ls, dim=0) # 2000个proposals是正类和负类的得分 [2000,2]
proposal_deltas = torch.cat(full_bboxes_ls, dim=0) # 2000个proposals的位置 [2000,4]
pred_cls = torch.cat(full_cls_ls, dim=0) #.unsqueeze(-1) # # 2000个proposals对应的类别
predictions = class_logits, proposal_deltas
推断模型预测结果
得到每个proposal的类别得分和坐标偏移之后,可以推断其是正负类的概率以及最后的预测框坐标。
# 根据预测的proposals偏移量以及propsals位置得到模型预测框
boxes = self.predict_boxes(predictions, proposals)
# 经过softmax得到每个proposals是正类和负类的概率
scores = self.predict_probs(predictions, proposals)
num_inst_per_image = [len(p) for p in proposals]
# 每个proposal对应的类别
pred_cls = pred_cls.split(num_inst_per_image, dim=0)
# 每个proposals的尺寸
image_shapes = [x.image_size for x in proposals]
将proposals为正类的概率作为置信度,设置阈值0.05只保留置信度高于0.05的proposals。在剩下的proposals中再进行一次非极大值抑制,最终只保留不超过100个proposals作为预测结果。
# proposals正类的概率
scores = scores[:, :-1] # [2000,1]
# 类别数 20
cls_num = pred_cls.unique().shape[0]
# 每个类别对应的proposals数目 100
box_num = int(scores.shape[0] / cls_num)
# scores[i,j]表示这个proposal是类别j的概率
scores = scores.reshape(cls_num, box_num).permute(1, 0) # [100,20]
boxes = boxes.reshape(cls_num, box_num, 4).permute(1, 0, 2).reshape(box_num, -1) # [100,80]
pred_cls = pred_cls.reshape(cls_num, box_num).permute(1, 0) # [100,20]
# 回归类别数 20
num_bbox_reg_classes = boxes.shape[1] // 4
# Convert to Boxes to use the `clip` function ...
boxes = Boxes(boxes.reshape(-1, 4))
boxes.clip(image_shape)
boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4) # [100,20,4]
# 找到置信度高于阈值的proposals 0.05
filter_mask = scores > score_thresh # [100,20]
# R行表示2000个proposals中有R个置信度符合要求,每行的数据表示该proposal的二维索引
# 第一列表示proposal的索引,第二列表示其对应的类别索引
filter_inds = filter_mask.nonzero() # [R,2]
# 符合要求的proposals位置
if num_bbox_reg_classes == 1:
boxes = boxes[filter_inds[:, 0], 0]
else:
boxes = boxes[filter_mask] # [R,4]
scores = scores[filter_mask] # 184
pred_cls = pred_cls[filter_mask] # 184
# 进行非极大值抑制,返回保留的proposals的索引
keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
if topk_per_image >= 0:
keep = keep[:topk_per_image]
boxes, scores, filter_inds, pred_cls = boxes[keep], scores[keep], filter_inds[keep], pred_cls[keep]
result = Instances(image_shape)
result.pred_boxes = Boxes(boxes)
result.scores = score
result.pred_classes = pred_cls
代码源于FewX