Few-Shot Object Detection with Attention-RPN and Multi-Relation Detector
Keywords:
few-shot object detection
Summary:
We propose a general few-shot object detection network that learns the matching metric between image pairs. It is based on the Faster R-CNN framework, equipped with our novel attention RPN and multi-relation detector, and trained using our contrastive training strategy.
Three main innovations:
- Attention-RPN: works at the early (proposal) stage, where it significantly improves proposal quality.
- Multi-Relation Detector: works at the later stage, where it suppresses and filters out false detections in confusing backgrounds.
- Contrastive training strategy: exploits the similarity between the few-shot support set and the query set to detect novel objects while suppressing false detections in the background.
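The pairing idea behind contrastive training can be sketched as follows. This is a minimal illustration, assuming a 2-way setup in which each query is matched with one support from its own category (positive pair, label 1) and one support from a different category (negative pair, label 0). The function `sample_contrastive_triplet` and the file names are hypothetical; the official code's sampling, including how proposals are balanced between pairs, is not reproduced here.

```python
import random
from typing import Dict, List, Tuple


def sample_contrastive_triplet(
    query_category: int,
    supports_by_category: Dict[int, List[str]],
    rng: random.Random,
) -> Tuple[Tuple[str, int], Tuple[str, int]]:
    """Return (positive_support, 1) and (negative_support, 0) for one query.

    The positive support shares the query's category; the negative support
    comes from a randomly chosen different category.
    """
    positive = rng.choice(supports_by_category[query_category])
    other_categories = [c for c in supports_by_category if c != query_category]
    negative_category = rng.choice(other_categories)
    negative = rng.choice(supports_by_category[negative_category])
    return (positive, 1), (negative, 0)


# Hypothetical support pool: category id -> cropped support images
supports = {
    1: ["dog_01.jpg", "dog_02.jpg"],
    2: ["cat_01.jpg"],
    3: ["car_01.jpg", "car_02.jpg"],
}
rng = random.Random(0)
(pos, pos_label), (neg, neg_label) = sample_contrastive_triplet(1, supports, rng)
print(pos, pos_label, neg, neg_label)
```

Training on both pair types is what lets the detector learn to score matching pairs high and non-matching pairs low, rather than only learning generic objectness.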
Problems with the RPN in existing few-shot object detection:
- Candidate bounding boxes can easily miss unseen objects, or else many false detections are produced in the background. We believe this is caused by the region proposal network (RPN) assigning inappropriately low scores to good bounding boxes, which makes novel objects hard to detect.
- Without any support image information, the RPN is aimlessly active on every potential object with a high objectness score, even if it does not belong to the support category. This burdens the detector's subsequent classification step with a large number of irrelevant objects.
Problem definition:
Given a support image $s_c$ with a close-up of the target object and a query image $q_c$ which potentially contains objects of the support category $c$, the task is to find all the target objects belonging to the support category in the query and label them with tight bounding boxes. If the support set contains $N$ categories and $K$ examples for each category, the problem is dubbed $N$-way $K$-shot detection.
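To make the $N$-way $K$-shot setting concrete, here is a minimal sketch of drawing one support set: first $N$ categories, then $K$ support images per category. The function `sample_episode` and the toy category pool are hypothetical and not taken from the paper or the official code.

```python
import random
from typing import Dict, List


def sample_episode(
    supports_by_category: Dict[str, List[str]],
    n_way: int,
    k_shot: int,
    rng: random.Random,
) -> Dict[str, List[str]]:
    """Pick N categories, then K distinct support images for each of them."""
    categories = rng.sample(sorted(supports_by_category), n_way)
    return {c: rng.sample(supports_by_category[c], k_shot) for c in categories}


# Hypothetical pool of cropped support images per category
pool = {
    "dog": ["d1", "d2", "d3"],
    "cat": ["c1", "c2", "c3"],
    "car": ["r1", "r2", "r3"],
    "bus": ["b1", "b2", "b3"],
}
episode = sample_episode(pool, n_way=2, k_shot=2, rng=random.Random(42))
print(episode)  # 2 categories, 2 supports each (2-way 2-shot)
```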
Model architecture:
We build a weight-shared framework that consists of multiple branches: one branch is for the query set and the others are for the support set (for simplicity, the paper's figure shows only one support branch). The query branch of the weight-shared framework is a Faster R-CNN network, which contains an RPN and a detector.
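"Weight-shared" means the query branch and every support branch run the same feature extractor with the same parameters. A toy sketch, assuming a hypothetical `SharedFeatureExtractor` made of a single 1×1 convolution plus ReLU in place of the real ResNet backbone:

```python
import numpy as np


class SharedFeatureExtractor:
    """Toy stand-in for the weight-shared backbone: a 1x1 convolution + ReLU.

    A single weight matrix is created once and reused for every branch.
    """

    def __init__(self, in_channels: int, out_channels: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.weight = rng.standard_normal((in_channels, out_channels))

    def __call__(self, image: np.ndarray) -> np.ndarray:
        # image: (H, W, in_channels) -> (H, W, out_channels); a 1x1 conv is a
        # per-pixel linear map over channels
        return np.maximum(image @ self.weight, 0.0)


rng = np.random.default_rng(1)
query_image = rng.standard_normal((32, 32, 3))
support_crops = [rng.standard_normal((16, 16, 3)) for _ in range(3)]  # K = 3 supports

extractor = SharedFeatureExtractor(in_channels=3, out_channels=8)
query_feature = extractor(query_image)                          # query branch
support_features = [extractor(crop) for crop in support_crops]  # support branches
print(query_feature.shape, [f.shape for f in support_features])
```

Because the parameters are shared, query and support features live in the same embedding space, which is what makes comparing them meaningful.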
Attention-RPN
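We introduce support information into the RPN through an attention mechanism, guiding the RPN to produce relevant proposals while suppressing proposals of other categories.
Note that the attention uses a depth-wise convolution (depth-wise cross correlation), with the support feature acting as the kernel over the query feature map.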
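The depth-wise correlation can be sketched in NumPy as below, assuming channel-last arrays and a valid (unpadded) correlation. For a support kernel of size S×S×C and a query map of size H×W×C, output channel c is the 2-D correlation of query channel c with support channel c only; channels are never summed together. The function name `depthwise_correlation` is hypothetical. In the inference code later in this post, `res4_avg` has shape [1, 1024, 1, 1], i.e. S = 1, so the attention reduces to per-channel scaling of the query feature map.

```python
import numpy as np


def depthwise_correlation(query: np.ndarray, support: np.ndarray) -> np.ndarray:
    """Depth-wise cross correlation with the support feature as the kernel.

    query:   (H, W, C) query feature map
    support: (S, S, C) support feature used as a per-channel kernel
    returns: (H - S + 1, W - S + 1, C)
    """
    h, w, c = query.shape
    s = support.shape[0]
    assert support.shape == (s, s, c)
    out = np.zeros((h - s + 1, w - s + 1, c), dtype=query.dtype)
    for i in range(s):
        for j in range(s):
            # support[i, j] has shape (C,) and broadcasts over the shifted window
            out += support[i, j] * query[i:i + h - s + 1, j:j + w - s + 1]
    return out


rng = np.random.default_rng(0)
query = rng.standard_normal((8, 8, 4))
support = rng.standard_normal((1, 1, 4))  # S = 1 after global average pooling
attention = depthwise_correlation(query, support)
print(attention.shape)  # (8, 8, 4)
```

In PyTorch the same operation is usually written as a grouped convolution with `groups` equal to the channel count.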
Multi-Relation Detector
We propose a novel multi-relation detector to effectively measure the similarity between proposal boxes from the query and the support objects.
- global-relation head: learn a deep embedding for global matching
- local-correlation head: learn the pixel-wise and depth-wise cross correlation between support and query proposals
- patch-relation head: learn a deep non-linear metric for patch matching
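The local-correlation head can be illustrated with a minimal sketch, assuming that a query proposal and a support object are RoI-pooled to the same S×S×C size, that both pass through one shared 1×1 convolution, and that a single linear layer turns the resulting correlation vector into a matching score. The function `local_correlation_score` and its random weights are hypothetical; this is not the official head.

```python
import numpy as np


def local_correlation_score(
    proposal: np.ndarray,
    support: np.ndarray,
    conv_weight: np.ndarray,
    fc_weight: np.ndarray,
) -> float:
    """Matching score from a depth-wise correlation of two same-sized feature maps.

    proposal, support: (S, S, C) RoI features of a query proposal and a support object
    conv_weight:       (C, D) weight of a 1x1 convolution shared by both inputs
    fc_weight:         (D,) weight of the final fully connected layer
    """
    assert proposal.shape == support.shape
    p = proposal @ conv_weight  # 1x1 conv == per-pixel linear map over channels
    s = support @ conv_weight   # same weights for both inputs
    # An S x S kernel over an S x S map gives one output position per channel
    corr = (p * s).sum(axis=(0, 1))  # shape (D,)
    return float(corr @ fc_weight)


rng = np.random.default_rng(0)
proposal = rng.standard_normal((7, 7, 16))
support = rng.standard_normal((7, 7, 16))
conv_weight = rng.standard_normal((16, 8))
fc_weight = rng.standard_normal(8)
print(local_correlation_score(proposal, support, conv_weight, fc_weight))
```

The global-relation and patch-relation heads would score the same pair differently (global embedding vs. patch-level non-linear metric); the sketch only covers the local-correlation idea.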
Next, the training and test stages are explained with reference to both the paper and the official open-source code.
Inference stage
average feature
We obtain all the support features through the weight-shared network and use the average feature across all the supports belonging to the same category as that category's support feature.
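The averaging reduces to two mean operations, mirroring `res4_pooled.mean(0, True)` followed by `.mean(dim=[2, 3], keepdim=True)` in the code that follows. A NumPy sketch, assuming NCHW arrays as in PyTorch; `category_support_feature` is a hypothetical name and the channel counts are shrunk (64 and 128 instead of 1024 and 2048) to keep the example light:

```python
import numpy as np


def category_support_feature(pooled: np.ndarray, spatial_pool: bool) -> np.ndarray:
    """Average the K RoI-pooled support features of one category.

    pooled: (K, C, H, W) features of the K support objects (NCHW layout)
    """
    avg = pooled.mean(axis=0, keepdims=True)  # (1, C, H, W): mean over the K shots
    if spatial_pool:
        avg = avg.mean(axis=(2, 3), keepdims=True)  # (1, C, 1, 1): global average pooling
    return avg


rng = np.random.default_rng(0)
pooled_res4 = rng.standard_normal((10, 64, 14, 14))  # 10-shot, res4-like features
pooled_res5 = rng.standard_normal((10, 128, 7, 7))   # 10-shot, res5-like features
res4_avg = category_support_feature(pooled_res4, spatial_pool=True)
res5_avg = category_support_feature(pooled_res5, spatial_pool=False)
print(res4_avg.shape, res5_avg.shape)  # (1, 64, 1, 1) (1, 128, 7, 7)
```

The spatially pooled vector serves as the 1×1 kernel for the Attention-RPN, while the res5 feature keeps its 7×7 layout for the relation heads of the detector.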
```python
import os

import numpy as np
import pandas as pd
import torch
from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.structures import Boxes, ImageList

# Excerpt from a method of the detector model: `self` is the model instance
# (it provides device, pixel_mean/pixel_std, backbone, roi_heads and in_features).
support_dir = './support_dir'
support_file_name = os.path.join(support_dir, 'support_feature.pkl')
if not os.path.exists(support_file_name):
    support_path = './datasets/coco/10_shot_support_df.pkl'
    support_df = pd.read_pickle(support_path)  # load the support set table

    metadata = MetadataCatalog.get('coco_2017_train')
    # unmap the category mapping ids for COCO
    reverse_id_mapper = lambda dataset_id: metadata.thing_dataset_id_to_contiguous_id[dataset_id]  # noqa
    support_df['category_id'] = support_df['category_id'].map(reverse_id_mapper)

    support_dict = {'res4_avg': {}, 'res5_avg': {}}
    for cls in support_df['category_id'].unique():
        support_cls_df = support_df.loc[support_df['category_id'] == cls, :].reset_index()
        support_data_all = []
        support_box_all = []

        for index, support_img_df in support_cls_df.iterrows():
            # file_path points to a support image already cropped around the target object
            img_path = os.path.join('./datasets/coco', support_img_df['file_path'])
            support_data = utils.read_image(img_path, format='BGR')
            # HWC -> CHW tensor
            support_data = torch.as_tensor(np.ascontiguousarray(support_data.transpose(2, 0, 1)))
            support_data_all.append(support_data)

            # bounding box of the support object inside the cropped image
            support_box = support_img_df['support_box']
            support_box_all.append(Boxes([support_box]).to(self.device))

        # move the cropped support images to the device, normalize them, and batch them
        support_images = [x.to(self.device) for x in support_data_all]
        support_images = [(x - self.pixel_mean) / self.pixel_std for x in support_images]
        support_images = ImageList.from_tensors(support_images, self.backbone.size_divisibility)
        # pass all support objects of this category through the backbone
        support_features = self.backbone(support_images.tensor)

        # the support images are padded when batched, so RoI pooling with the support boxes is still needed
        res4_pooled = self.roi_heads.roi_pooling(support_features, support_box_all)
        # paper: average the features of all supports of this category
        res4_avg = res4_pooled.mean(0, True)
        # paper: global average pooling over the spatial dims, giving the depth-wise kernel res4_avg
        res4_avg = res4_avg.mean(dim=[2, 3], keepdim=True)  # [1, 1024, 1, 1]
        support_dict['res4_avg'][cls] = res4_avg.detach().cpu().data

        # not described in the paper: _shared_roi_transform also starts with RoI pooling,
        # then runs the pooled features through the res5 stage
        res5_feature = self.roi_heads._shared_roi_transform(
            [support_features[f] for f in self.in_features], support_box_all)
        res5_avg = res5_feature.mean(0, True)  # [1, 2048, 7, 7]
        support_dict['res5_avg'][cls] = res5_avg.detach().cpu().data
```