Few-shot learning paper notes: FSOD, Few-Shot Object Detection with Attention-RPN and Multi-Relation Detector

"Few-Shot Object Detection with Attention-RPN and Multi-Relation Detector"

Keywords:
few-shot object detection

Summary:
We propose a general few-shot object detection network that learns the matching metric between image pairs. It is based on the Faster R-CNN framework, equipped with our novel attention RPN and multi-relation detector, and trained using our contrastive training strategy.

Three main contributions:

  • Attention-RPN: acts at the early stage, where it significantly improves proposal quality.
  • Multi-Relation Detector: acts at the later stage, suppressing and filtering out false detections in confusing backgrounds.
  • Contrastive training strategy: exploits the similarity between the few-shot support set and the query set to detect novel objects while suppressing false detections in the background.

Problems with the RPN in existing few-shot object detection:

  • Potential bounding boxes can easily miss unseen objects, or else many false detections in the background can be produced. We believe this is caused by the inappropriately low scores that a region proposal network (RPN) assigns to good bounding boxes, which makes novel objects hard to detect.
  • Without any support image information, the RPN will be aimlessly active on every potential object with a high objectness score, even if it does not belong to the support category, thus burdening the subsequent classification task of the detector with a large number of irrelevant objects.

Problem definition:
Given a support image $s_c$ with a close-up of the target object and a query image $q_c$ which potentially contains objects of the support category $c$, the task is to find all the target objects belonging to the support category in the query and label them with tight bounding boxes. If the support set contains $N$ categories and $K$ examples for each category, the problem is dubbed $N$-way $K$-shot detection.
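To make the N-way K-shot terminology concrete, here is a minimal sketch of drawing such a support set from a list of annotations. The helper name `sample_support_set`, the `(image_id, category)` tuple layout, and the toy data are assumptions made for illustration; they do not come from the paper or the official code.

```python
import random
from collections import defaultdict


def sample_support_set(annotations, n_way, k_shot, seed=0):
    """Pick n_way categories and k_shot examples for each (hypothetical helper).

    annotations: iterable of (image_id, category) pairs.
    Returns a dict mapping category -> list of k_shot image ids.
    """
    by_category = defaultdict(list)
    for image_id, category in annotations:
        by_category[category].append(image_id)

    # Only categories with at least k_shot examples can be sampled.
    eligible = sorted(c for c, ids in by_category.items() if len(ids) >= k_shot)
    if len(eligible) < n_way:
        raise ValueError("not enough categories with k_shot examples")

    rng = random.Random(seed)
    chosen = rng.sample(eligible, n_way)
    return {c: rng.sample(by_category[c], k_shot) for c in chosen}


# Toy data (assumed): 3 categories with 4 cropped examples each.
toy_annotations = [(f"{cat}_{i}", cat) for cat in ("cat", "dog", "bird") for i in range(4)]
support_set = sample_support_set(toy_annotations, n_way=2, k_shot=3)
```

With `n_way=2` and `k_shot=3`, the result is a 2-way 3-shot support set: two categories, each with three distinct support examples.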

Model architecture:

[Figure: model architecture]
We build a weight-shared framework that consists of multiple branches, where one branch is for the query set and the others are for the support set (for simplicity, we only show one support branch in the figure). The query branch of the weight-shared framework is a Faster R-CNN network, which contains an RPN and a detector.

Attention-RPN

We introduce support information to the RPN through an attention mechanism, guiding the RPN to produce relevant proposals while suppressing proposals of other categories.
Note that a depth-wise convolution is used.
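A minimal sketch of the depth-wise cross-correlation behind the attention map may help here. It uses NumPy rather than the PyTorch layers of the official code, and the function name `depthwise_cross_correlation` and the toy shapes are assumptions. Each channel of the support feature is correlated only with the same channel of the query feature; with a 1x1 support kernel this reduces to per-channel scaling of the query feature map.

```python
import numpy as np


def depthwise_cross_correlation(query, support):
    """Correlate each query channel with the matching support channel.

    query:   array of shape (C, H, W)
    support: array of shape (C, S, S), used as a per-channel kernel
    Returns an array of shape (C, H - S + 1, W - S + 1) ("valid" correlation).
    """
    c, h, w = query.shape
    cs, s, s2 = support.shape
    assert c == cs and s == s2, "channel counts must match and kernel must be square"
    out_h, out_w = h - s + 1, w - s + 1
    out = np.zeros((c, out_h, out_w), dtype=query.dtype)
    # Accumulate one kernel offset at a time; channels never mix.
    for i in range(s):
        for j in range(s):
            out += support[:, i, j][:, None, None] * query[:, i:i + out_h, j:j + out_w]
    return out


# Toy example (assumed sizes): 4 channels, a 5x6 query map, a 1x1 support kernel.
rng = np.random.default_rng(0)
query_feat = rng.standard_normal((4, 5, 6))
support_avg = rng.standard_normal((4, 1, 1))
attention = depthwise_cross_correlation(query_feat, support_avg)
```

Because the kernel is 1x1 in this toy case, `attention` has the same spatial size as `query_feat`, which is what allows it to be fed on to the RPN head.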

Multi-Relation Detector

We propose a novel multi-relation detector to effectively measure the similarity between proposal boxes from the query and the support objects.

  • global-relation head: learn a deep embedding for global matching
  • local-correlation head: learn the pixel-wise and depth-wise cross correlation between support and query proposals
  • patch-relation head: learn a deep non-linear metric for patch matching
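The three heads differ mainly in how they combine the support feature with a query proposal feature before any learned layers are applied. The sketch below shows only that combination step in NumPy, under the assumption that both features have already been pooled to the same `(C, H, W)` shape; the fully connected and convolutional layers that follow are omitted, and the function name `relation_head_inputs` is hypothetical.

```python
import numpy as np


def relation_head_inputs(support_feat, proposal_feat):
    """Build the (unlearned) input of each relation head from two (C, H, W) features."""
    assert support_feat.shape == proposal_feat.shape
    # Channel-wise concatenation, shape (2C, H, W).
    stacked = np.concatenate([support_feat, proposal_feat], axis=0)

    # Global relation: pool the concatenated map to one 2C vector for an MLP.
    global_in = stacked.mean(axis=(1, 2))
    # Local correlation: depth-wise correlation with a full-size kernel,
    # i.e. a per-channel sum of element-wise products, shape (C,).
    local_in = (support_feat * proposal_feat).sum(axis=(1, 2))
    # Patch relation: keep the spatial layout so convolutions can compare patches.
    patch_in = stacked
    return global_in, local_in, patch_in


# Toy example (assumed sizes): 8 channels on a 7x7 pooled grid.
rng = np.random.default_rng(0)
support_feat = rng.standard_normal((8, 7, 7))
proposal_feat = rng.standard_normal((8, 7, 7))
global_in, local_in, patch_in = relation_head_inputs(support_feat, proposal_feat)
```

Each head then maps its input to a matching score with its own learned layers, so the three scores look at the same pair from a global, a per-pixel, and a patch-level view.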

Next, the training and testing stages are explained with reference to the official open-source code and the paper.

Inference stage
Average feature

We obtain all the support features through the weight-shared network and use the average feature across all the supports belonging to the same category as its support feature.

# Excerpt from inside a model method: `self` is the detector, so this is not a standalone script.
# Imports the excerpt relies on (assuming the detectron2-based official code).
import os

import numpy as np
import pandas as pd
import torch
import detectron2.data.detection_utils as utils
from detectron2.data import MetadataCatalog
from detectron2.structures import Boxes, ImageList

support_dir = './support_dir'

support_file_name = os.path.join(support_dir, 'support_feature.pkl')
if not os.path.exists(support_file_name):
    support_path = './datasets/coco/10_shot_support_df.pkl'
    support_df = pd.read_pickle(support_path)  # load the support data

    metadata = MetadataCatalog.get('coco_2017_train')
    # unmap the category mapping ids for COCO
    reverse_id_mapper = lambda dataset_id: metadata.thing_dataset_id_to_contiguous_id[dataset_id]  # noqa
    support_df['category_id'] = support_df['category_id'].map(reverse_id_mapper)

    support_dict = {'res4_avg': {}, 'res5_avg': {}}
    for cls in support_df['category_id'].unique():
        support_cls_df = support_df.loc[support_df['category_id'] == cls, :].reset_index()
        support_data_all = []
        support_box_all = []

        for index, support_img_df in support_cls_df.iterrows():
            # file_path points to the already-cropped support object
            img_path = os.path.join('./datasets/coco', support_img_df['file_path'])
            support_data = utils.read_image(img_path, format='BGR')
            support_data = torch.as_tensor(np.ascontiguousarray(support_data.transpose(2, 0, 1)))
            support_data_all.append(support_data)

            # box information for the support image
            support_box = support_img_df['support_box']
            support_box_all.append(Boxes([support_box]).to(self.device))

        # cropped support images: move to device, normalize, and batch
        support_images = [x.to(self.device) for x in support_data_all]
        support_images = [(x - self.pixel_mean) / self.pixel_std for x in support_images]
        support_images = ImageList.from_tensors(support_images, self.backbone.size_divisibility)
        # run all support objects through the backbone to get features
        support_features = self.backbone(support_images.tensor)
        # the support images were padded (among other operations), so an RoI pooling step is still needed
        res4_pooled = self.roi_heads.roi_pooling(support_features, support_box_all)
        # matches the paper: average the features over all supports of the category
        res4_avg = res4_pooled.mean(0, True)
        # matches the paper: depth-wise global average pooling, stored as res4_avg
        res4_avg = res4_avg.mean(dim=[2, 3], keepdim=True)  # [1,1024,1,1]
        support_dict['res4_avg'][cls] = res4_avg.detach().cpu().data

        # not described in the paper: _shared_roi_transform also performs RoI pooling, which seems odd
        res5_feature = self.roi_heads._shared_roi_transform(
            [support_features[f] for f in self.in_features], support_box_all)
        res5_avg = res5_feature.mean(0, True)  # [1,2048,7,7]
        support_dict['res5_avg'][cls] = res5_avg.detach().cpu().data
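The two averaging steps applied to `res4_pooled` in the excerpt above can be checked on toy arrays. This is a NumPy sketch with made-up sizes (10 support crops, 16 channels, a 14x14 pooled map), not the shapes used by the official model; `axis=0, keepdims=True` plays the role of `mean(0, True)` in PyTorch.

```python
import numpy as np

# Hypothetical sizes: K=10 support crops of one category, C=16 channels, 14x14 pooled map.
rng = np.random.default_rng(0)
pooled = rng.standard_normal((10, 16, 14, 14))

# Step 1: average across the K supports belonging to the same category.
class_avg = pooled.mean(axis=0, keepdims=True)        # shape (1, 16, 14, 14)

# Step 2: global average pooling per channel, giving one value per channel.
kernel = class_avg.mean(axis=(2, 3), keepdims=True)   # shape (1, 16, 1, 1)
```

The resulting `(1, C, 1, 1)` array has the layout of a 1x1 depth-wise kernel, which is consistent with the `[1,1024,1,1]` shape noted for `res4_avg`.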
Inference