FSOD主要代码分析

本文详细分析了FSOD(Few-shotObjectDetection)的代码实现,主要涉及attention-rpn模块,通过支持图像特征对查询图像进行卷积,生成相关性特征图。接着介绍了多关系匹配,包括全局、对应位置和块匹配三种方式,用于提高匹配精度。最后,综合各个模块的分数预测bbox并进行分类。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

FSOD主要代码分析

FSOD(FewX)原论文代码是基于detectron2写的,本篇代码来自于DAnA论文中对FSOD的对照实验,之后还会写一点DAnA模型的主要代码分析笔记。

源码地址

1、网络结构总览

​ 查询图像和支持图像由权重共享网络处理,通过attention-rpn关注给定的支持类别来过滤掉其他类别的提案,再用多关系检测器匹配查询提案和支持对象,进行达到进一步筛选。

在这里插入图片描述

2、attention-rpn

在这里插入图片描述

​ 因为我们希望查询图像每一个位置都和支持图像进行匹配,所以把支持图像特征图作为卷积核对查询图像进行卷积,这样就得到擦洗混图像每一个位置和支持图像特征图的內积,这个內积就可以表示成一种相关性。

       # attention rpn
       heatmaps = []
        for kernel, feat in zip(pos_support_feat.chunk(pos_support_feat.size(0), dim=0), base_feat.chunk(base_feat.size(0), dim=0)):
            kernel = kernel.view(pos_support_feat.size(1), 1, pos_support_feat.size(2), pos_support_feat.size(3))
            # 支持特征被作为内核,在查询特征图上面滑动
            heatmap = F.conv2d(feat, kernel, groups=1024)  # [1, 1024, h, w]
            heatmap = heatmap.squeeze()
            heatmaps += [heatmap]
        correlation_feat = torch.stack(heatmaps, 0)

​ 生成表示支持图像特征和查询特征相关性的注意力特征图G,然后将注意力特征图输入RPN网络生成查询候选框。

#fosd.py
# rois [B, RPN_POST_NMS_TOP_N(2000), 5], 5 is [batch_num, x1, y1, x2, y2]rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(correlation_feat, im_info, gt_boxes, num_boxes)

fsod.py中定义rpn:

        # define rpn
        self.RCNN_rpn = _RPN(self.dout_base_model)

rpn.py:

class _RPN(nn.Module):
...
	  def forward(self, base_feat, im_info, gt_boxes, num_boxes):

3、多关系匹配

在这里插入图片描述

1、全局关系匹配:
  ## global relation
        if self.global_relation:
            global_feats = []
            current_b = 0
            for i, feat in enumerate(pooled_feat.chunk(pooled_feat.size(0), dim=0)):
                if i == (current_b + 1) * self.num_of_rois:
                    current_b += 1
                current_support_feat = support_pooled_feat[current_b].unsqueeze(0)
                # current_support_feat[0]与一张查询图像多个候选框特征图进行拼接,得到一个数组
                concat_feat = torch.cat((feat, current_support_feat), 1)  # [1, 2c, 7, 7]
                # 将每张支持图像与多个查询图像候选框得到的拼接结果存入数组,得到每张支持图像与多个查询图像候选框拼接结果
                global_feats += [concat_feat]
            global_feats = torch.cat(global_feats, 0)  # [B*128, 2c, 7, 7]
            global_feats = self.avgpool_fc(global_feats).squeeze(3).squeeze(2)  # [B*128, 2c]
            out_fc = F.relu(self.global_fc_1(global_feats), inplace=True)
            out_fc = F.relu(self.global_fc_2(out_fc), inplace=True)  # [B*128, c]
            global_cls_score = self.global_cls_score(out_fc)  # [B*128, 2]
2、对应位置关系:
 		## local correlation
 		if self.local_correlation:
            corr_rois = self.corr_conv(pooled_feat)  # [B*128, c, 7, 7]
            corr_support = self.corr_conv(support_pooled_feat)  # [B, c, 7, 7]
            heatmaps = []
            # 因为我们希望每一个位置都和支持图像特征进行匹配,所以把支持图像特征图作为卷积核,这样就得到了每一个位置和卷积核进行內积,这个內积就可以表示成一种相关性
            for b, grouped_feat in enumerate(corr_rois.chunk(corr_support.size(0), dim=0)):
                # 支持图像的特征图保持7x7XC的大小
                kernel = corr_support[b].unsqueeze(0)  # [1, c, 7, 7]
                kernel = kernel.view(kernel.size(1), 1, kernel.size(2), kernel.size(3))  # [c, 1, 7, 7]
                # 卷积得到热力图
                # 7X7XC大小的支持图像特征图作为kenel在查询图像候选框的特征图上(group_feat)进行卷积,变成1X1XC(变相实现了像素级的相似度计算)
                heatmap = F.conv2d(grouped_feat, kernel, groups=self.pool_feat_dim)  # [128, c, 1, 1]
                heatmap = heatmap.squeeze()
                # 将热力图依次以数组方式存放
                heatmaps += [heatmap]
                # 将热力图数组中的热力图进行拼接
            out_corr = torch.cat(heatmaps, 0)  # [B*128, c]
            # 最后通过一个全连接层生成匹配分数,得到相似度
            corr_cls_score = self.corr_cls_score(out_corr)  # [B*128, 2]
3、块匹配:
        ## patch relation
        if self.patch_relation:
            batch_size = support_pooled_feat.size(0)
            patch_feats = []
            for b, roi_feats in enumerate(pooled_feat.chunk(batch_size, dim=0)):
                current_support_feat = support_pooled_feat[b].expand(roi_feats.size(0), -1, -1, -1)  # [128, c, 7, 7]
                # 将支持特征图和查询候选框特征图进行拼接
                patch = torch.cat((roi_feats, current_support_feat), 1)  # [128, 2c, 7, 7]
                patch_feats += [patch]
            x = torch.cat(patch_feats, dim=0)
            # 使用一些列relu和pooling将特征图从7x7下采样到1x1
            x = F.relu(self.patch_conv_1(x), inplace=True)
            x = self.patch_avgpool(x)
            x = F.relu(self.patch_conv_2(x), inplace=True) # 3x3
            x = F.relu(self.patch_conv_3(x), inplace=True) # 3x3
            x = self.patch_avgpool(x) # 1x1
            x = x.squeeze(3).squeeze(2)
            # 通过一个全连接层输出相似度分数
            patch_cls_score = self.patch_cls_score(x)

产生bbox的预测:

 # box regression
        bbox_pred = self.RCNN_bbox_pred(self._head_to_tail(pooled_feat))  # [B*128, 4]

将三个模块的分数求和作为最终的匹配分数

        # score and prob
        if self.patch_relation:
            cls_score_all = (global_cls_score + corr_cls_score + patch_cls_score) / self.soft_gamma
        else:
            cls_score_all = (global_cls_score + corr_cls_score) / self.soft_gamma
        cls_prob = F.softmax(cls_score_all, 1)  # [B*128, 1]

        return bbox_pred, cls_prob, cls_score_all 

总结

身为目标检测刚入门的菜鸟,粗浅地写了一点代码阅读笔记,可能存在错误,之后有新的认识会进行修改,如有错误,也欢迎指出。_

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值