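Analysis of the Main FSOD Code
The original FSOD (FewX) code is built on detectron2. The code discussed here instead comes from the FSOD baseline that the DAnA paper uses as a comparison experiment. I plan to write up notes on the main code of the DAnA model later.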
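1. Network architecture overview
The query image and the support image are processed by a weight-shared network. The attention RPN attends to the given support category and filters out proposals that belong to other categories. A multi-relation detector then matches the query proposals against the support object to filter them further.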
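2. Attention RPN
We want every position of the query image to be matched against the support image, so the support feature map is used as a convolution kernel and slid over the query feature map. This gives the inner product between the support feature map and each position of the query feature map, and that inner product serves as a measure of correlation. Because the convolution is called with groups=1024, the correlation is computed channel by channel (depth-wise).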
# attention rpn
heatmaps = []
for kernel, feat in zip(pos_support_feat.chunk(pos_support_feat.size(0), dim=0), base_feat.chunk(base_feat.size(0), dim=0)):
    kernel = kernel.view(pos_support_feat.size(1), 1, pos_support_feat.size(2), pos_support_feat.size(3))
    # the support feature acts as the kernel and slides over the query feature map
    heatmap = F.conv2d(feat, kernel, groups=1024)  # [1, 1024, h, w]
    heatmap = heatmap.squeeze()
    heatmaps += [heatmap]
correlation_feat = torch.stack(heatmaps, 0)
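This produces the attention feature map G, which represents the correlation between the support features and the query features. G is then fed into the RPN to generate the query proposals.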
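To make the depth-wise correlation above concrete, here is a minimal sketch on toy tensors. The function name `depthwise_correlation`, the toy sizes, and the 1x1 support kernel are my own assumptions for illustration and are not taken from the repository.

```python
import torch
import torch.nn.functional as F


def depthwise_correlation(query_feat, support_feat):
    """Correlate each query feature map with its own support kernel, per channel.

    query_feat:   [B, C, H, W] query feature maps
    support_feat: [B, C, k, k] support features used as kernels
    returns:      [B, C, H-k+1, W-k+1]
    """
    batch, channels = query_feat.shape[:2]
    heatmaps = []
    for b in range(batch):
        # [C, k, k] -> [C, 1, k, k]: one single-channel kernel per feature channel
        kernel = support_feat[b].unsqueeze(1)
        # groups=channels makes the convolution depth-wise
        heatmap = F.conv2d(query_feat[b:b + 1], kernel, groups=channels)
        heatmaps.append(heatmap.squeeze(0))
    return torch.stack(heatmaps, 0)


torch.manual_seed(0)
query = torch.randn(2, 8, 5, 6)    # toy sizes, not the real 1024-channel features
support = torch.randn(2, 8, 1, 1)  # assumed 1x1 support kernel
attention = depthwise_correlation(query, support)
```

With a 1x1 kernel the result is simply the query feature map scaled channel by channel by the support vector, which is one way to see why the output can be read as an attention map.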
# fsod.py
# rois [B, RPN_POST_NMS_TOP_N(2000), 5], 5 is [batch_num, x1, y1, x2, y2]
rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(correlation_feat, im_info, gt_boxes, num_boxes)
The RPN is defined in fsod.py:
# define rpn
self.RCNN_rpn = _RPN(self.dout_base_model)
rpn.py:
class _RPN(nn.Module):
    ...
    def forward(self, base_feat, im_info, gt_boxes, num_boxes):
3. Multi-relation matching
1. Global relation matching:
## global relation
if self.global_relation:
    global_feats = []
    current_b = 0
    for i, feat in enumerate(pooled_feat.chunk(pooled_feat.size(0), dim=0)):
        if i == (current_b + 1) * self.num_of_rois:
            current_b += 1
        current_support_feat = support_pooled_feat[current_b].unsqueeze(0)
        # concatenate the support feature of the current image with the feature
        # of one query proposal along the channel dimension
        concat_feat = torch.cat((feat, current_support_feat), 1)  # [1, 2c, 7, 7]
        # collect the concatenated result of every query proposal
        global_feats += [concat_feat]
    global_feats = torch.cat(global_feats, 0)  # [B*128, 2c, 7, 7]
    global_feats = self.avgpool_fc(global_feats).squeeze(3).squeeze(2)  # [B*128, 2c]
    out_fc = F.relu(self.global_fc_1(global_feats), inplace=True)
    out_fc = F.relu(self.global_fc_2(out_fc), inplace=True)  # [B*128, c]
    global_cls_score = self.global_cls_score(out_fc)  # [B*128, 2]
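The per-proposal loop above pairs every query proposal with the support feature of its own image. As a rough sketch of the same pairing without the Python loop, the support features can be repeated once per proposal and concatenated in a single call. The helper name `concat_rois_with_support` and the toy sizes are my own assumptions; the repository itself uses the loop shown above.

```python
import torch


def concat_rois_with_support(pooled_feat, support_pooled_feat, num_of_rois):
    """Pair each RoI feature with the support feature of its image.

    pooled_feat:         [B*R, C, 7, 7] RoI features, grouped image by image
    support_pooled_feat: [B, C, 7, 7]   one support feature per image
    returns:             [B*R, 2C, 7, 7]
    """
    # repeat every support feature R times so that row i lines up with RoI i
    support_per_roi = support_pooled_feat.repeat_interleave(num_of_rois, dim=0)
    return torch.cat((pooled_feat, support_per_roi), dim=1)


torch.manual_seed(0)
B, R, C = 2, 3, 4  # toy sizes: 2 images, 3 proposals each, 4 channels
rois = torch.randn(B * R, C, 7, 7)
supports = torch.randn(B, C, 7, 7)
global_feats = concat_rois_with_support(rois, supports, R)
# global average pooling, standing in for the avgpool_fc step
global_vec = global_feats.mean(dim=(2, 3))  # [B*R, 2C]
```

The first C channels of each row come from the proposal and the last C channels come from the support image, which is what the fully connected layers then compare.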
2. Local (position-wise) correlation:
## local correlation
if self.local_correlation:
    corr_rois = self.corr_conv(pooled_feat)  # [B*128, c, 7, 7]
    corr_support = self.corr_conv(support_pooled_feat)  # [B, c, 7, 7]
    heatmaps = []
    # we want every position to be matched against the support feature, so the
    # support feature map is used as the convolution kernel; the resulting
    # inner product serves as a measure of correlation
    for b, grouped_feat in enumerate(corr_rois.chunk(corr_support.size(0), dim=0)):
        # the support feature map keeps its 7x7xC size
        kernel = corr_support[b].unsqueeze(0)  # [1, c, 7, 7]
        kernel = kernel.view(kernel.size(1), 1, kernel.size(2), kernel.size(3))  # [c, 1, 7, 7]
        # convolve to get the heatmap:
        # the 7x7xC support feature map is used as the kernel and convolved over the
        # query proposal features (grouped_feat), giving 1x1xC
        # (in effect a pixel-level similarity computation)
        heatmap = F.conv2d(grouped_feat, kernel, groups=self.pool_feat_dim)  # [128, c, 1, 1]
        heatmap = heatmap.squeeze()
        # store the heatmaps one by one in a list
        heatmaps += [heatmap]
    # concatenate the heatmaps in the list
    out_corr = torch.cat(heatmaps, 0)  # [B*128, c]
    # finally, a fully connected layer produces the matching score (the similarity)
    corr_cls_score = self.corr_cls_score(out_corr)  # [B*128, 2]
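To check the claim that this convolution amounts to a position-wise similarity, the small sketch below compares the depth-wise convolution with an explicit per-channel inner product. The tensor names and toy sizes are my own assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_rois, channels = 4, 6  # toy sizes
roi_feats = torch.randn(num_rois, channels, 7, 7)
support_feat = torch.randn(channels, 7, 7)

# depth-wise convolution with a kernel as large as the feature map: 7x7 -> 1x1
kernel = support_feat.unsqueeze(1)  # [c, 1, 7, 7]
via_conv = F.conv2d(roi_feats, kernel, groups=channels).flatten(1)  # [num_rois, c]

# the same quantity written out: multiply position by position, sum over 7x7
via_sum = (roi_feats * support_feat.unsqueeze(0)).sum(dim=(2, 3))  # [num_rois, c]
```

The two results agree up to floating-point error. The sketch uses `flatten(1)` instead of a bare `squeeze()` so that a batch containing a single proposal would still keep its leading dimension.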
3. Patch relation matching:
## patch relation
if self.patch_relation:
    batch_size = support_pooled_feat.size(0)
    patch_feats = []
    for b, roi_feats in enumerate(pooled_feat.chunk(batch_size, dim=0)):
        current_support_feat = support_pooled_feat[b].expand(roi_feats.size(0), -1, -1, -1)  # [128, c, 7, 7]
        # concatenate the support feature map with the query proposal feature maps
        patch = torch.cat((roi_feats, current_support_feat), 1)  # [128, 2c, 7, 7]
        patch_feats += [patch]
    x = torch.cat(patch_feats, dim=0)
    # a series of conv + ReLU and pooling layers downsamples the feature map from 7x7 to 1x1
    x = F.relu(self.patch_conv_1(x), inplace=True)
    x = self.patch_avgpool(x)
    x = F.relu(self.patch_conv_2(x), inplace=True)  # 3x3
    x = F.relu(self.patch_conv_3(x), inplace=True)  # 3x3
    x = self.patch_avgpool(x)  # 1x1
    x = x.squeeze(3).squeeze(2)
    # a fully connected layer outputs the similarity score
    patch_cls_score = self.patch_cls_score(x)
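The layer definitions are not shown in the excerpt above, so the following is only a sketch of how the spatial size could shrink from 7x7 to 1x1. The kernel sizes, strides, and channel counts are hypothetical values that I picked so the intermediate sizes match the `# 3x3` and `# 1x1` comments; the real settings in the repository may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

c = 8  # toy channel count

# Hypothetical layer settings (assumptions, not read from the repository)
patch_conv_1 = nn.Conv2d(2 * c, c // 4, kernel_size=1, bias=False)
patch_conv_2 = nn.Conv2d(c // 4, c // 4, kernel_size=3, bias=False)
patch_conv_3 = nn.Conv2d(c // 4, c, kernel_size=1, bias=False)
patch_avgpool = nn.AvgPool2d(kernel_size=3, stride=1)

x = torch.randn(5, 2 * c, 7, 7)  # 5 concatenated proposal/support pairs
sizes = [tuple(x.shape[2:])]     # record the spatial size after every step

for layer, use_relu in [
    (patch_conv_1, True),    # 1x1 conv keeps 7x7
    (patch_avgpool, False),  # 3x3 pooling, stride 1: 7x7 -> 5x5
    (patch_conv_2, True),    # 3x3 conv, no padding: 5x5 -> 3x3
    (patch_conv_3, True),    # 1x1 conv keeps 3x3
    (patch_avgpool, False),  # 3x3 pooling, stride 1: 3x3 -> 1x1
]:
    x = layer(x)
    if use_relu:
        x = F.relu(x)
    sizes.append(tuple(x.shape[2:]))

patch_vec = x.squeeze(3).squeeze(2)  # [5, c]
```

Under these assumed settings no padding is used anywhere, so each 3x3 operation removes two rows and two columns, which is what brings 7x7 down to 1x1.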
Bounding-box prediction:
# box regression
bbox_pred = self.RCNN_bbox_pred(self._head_to_tail(pooled_feat)) # [B*128, 4]
The scores of the three modules are summed and divided by self.soft_gamma to give the final matching score:
# score and prob
if self.patch_relation:
    cls_score_all = (global_cls_score + corr_cls_score + patch_cls_score) / self.soft_gamma
else:
    cls_score_all = (global_cls_score + corr_cls_score) / self.soft_gamma
cls_prob = F.softmax(cls_score_all, 1)  # [B*128, 2]
return bbox_pred, cls_prob, cls_score_all
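As a small illustration of the score fusion, the sketch below sums the per-head scores, divides by a temperature, and applies softmax. The function name `fuse_scores`, the example scores, and the `soft_gamma` values are made up for illustration; I have not checked which value the repository actually uses.

```python
import torch
import torch.nn.functional as F


def fuse_scores(score_list, soft_gamma=1.0):
    """Sum the per-head scores, scale by 1 / soft_gamma, and apply softmax.

    score_list: list of [N, 2] score tensors, one per relation head
    returns:    (cls_prob [N, 2], cls_score_all [N, 2])
    """
    cls_score_all = torch.stack(score_list, 0).sum(0) / soft_gamma
    cls_prob = F.softmax(cls_score_all, dim=1)
    return cls_prob, cls_score_all


# made-up scores for a single proposal from the three heads
global_score = torch.tensor([[2.0, 0.0]])
corr_score = torch.tensor([[1.0, 0.0]])
patch_score = torch.tensor([[1.0, 0.0]])
heads = [global_score, corr_score, patch_score]

sharp_prob, _ = fuse_scores(heads, soft_gamma=1.0)
soft_prob, _ = fuse_scores(heads, soft_gamma=4.0)
```

A larger `soft_gamma` flattens the softmax output, so the same summed scores lead to a less confident probability.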
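Summary
I am still a beginner in object detection, and these are only rough code-reading notes, so they may contain mistakes. I will revise them as my understanding improves, and corrections are welcome.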