【mmdetection】统计RPN结构的test输出中的roi相对感受野中心的偏移
😄 😆 😊 😃 😏 😍 😘 😚 😳 😌 😆 😁 😉 😜 😝 😀 😗 😙 😛 😴 😟 😦 😧 😮 😬 😕 😯 😑 😒 😅 😓 😥 😩 😔 😞 😖 😨 😰 😣 😢 😭 😂 😲 😱
本文仅用于本人实验使用,如有错误请多包涵
使用的 mmdetection 版本为 2.14.0。
mmdetection中的ATSS模型部分可查看官方文章《轻松掌握 MMDetection 中常用算法(四):ATSS》
注意:追溯 test 流程的部分代码时,要注意相关方法继承自 BBoxTestMixin 类。
代码内容
在mmdet/models/dense_heads/atss_head.py
中:重新实现_get_bboxes
函数,将统计结果写入log文件中
def _get_bboxes(self,
                cls_scores,
                bbox_preds,
                centernesses,
                mlvl_anchors,
                img_shapes,
                scale_factors,
                cfg,
                rescale=False,
                with_nms=True):
    """Transform outputs for a single batch item into labeled boxes.

    Modified from the stock ``ATSSHead._get_bboxes``: additionally logs,
    for every detection kept by NMS, the offset of its anchor
    (receptive-field) center relative to the decoded box center,
    normalized by the box half width/height.

    Args:
        cls_scores (list[Tensor]): Box scores for a single scale level
            with shape (N, num_anchors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for a single
            scale level with shape (N, num_anchors * 4, H, W).
        centernesses (list[Tensor]): Centerness for a single scale level
            with shape (N, num_anchors * 1, H, W).
        mlvl_anchors (list[Tensor]): Box reference for a single scale level
            with shape (num_total_anchors, 4). Produced by
            ``anchor_generator.grid_anchors`` inside ``get_bboxes``.
        img_shapes (list[tuple[int]]): Shape of the input image,
            list[(height, width, 3)].
        scale_factors (list[ndarray]): Scale factor of the image arranged
            as (w_scale, h_scale, w_scale, h_scale).
        cfg (mmcv.Config | None): Test / postprocessing configuration,
            if None, test_cfg would be used.
        rescale (bool): If True, return boxes in original image space.
            Default: False.
        with_nms (bool): If True, do nms before return boxes.
            Default: True.

    Returns:
        list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
            The first item is an (n, 5) tensor, where 5 represent
            (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
            The shape of the second tensor in the tuple is (n,), and
            each element represents the class label of the corresponding
            box.
    """
    # NOTE(review): get_root_logger's first positional parameter is
    # log_file, so the statistics end up in a file literally named
    # 'INFO' in the working directory -- confirm this is intended.
    logger = get_root_logger('INFO')
    # Accumulate log fragments in a list and join once at the end;
    # repeated `str +=` in the per-detection loop is quadratic.
    log_parts = []
    assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
    device = cls_scores[0].device
    batch_size = cls_scores[0].shape[0]
    # convert to tensor to keep tracing
    nms_pre_tensor = torch.tensor(
        cfg.get('nms_pre', -1), device=device, dtype=torch.long)
    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_centerness = []
    base_anchors = []
    for cls_score, bbox_pred, centerness, anchors in zip(
            cls_scores, bbox_preds, centernesses, mlvl_anchors):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        scores = cls_score.permute(0, 2, 3, 1).reshape(
            batch_size, -1, self.cls_out_channels).sigmoid()
        centerness = centerness.permute(0, 2, 3, 1).reshape(
            batch_size, -1).sigmoid()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
        # Always keep topk op for dynamic input in onnx
        if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
                                   or scores.shape[-2] > nms_pre_tensor):
            from torch import _shape_as_tensor
            # keep shape as tensor and get k
            num_anchor = _shape_as_tensor(scores)[-2].to(device)
            nms_pre = torch.where(nms_pre_tensor < num_anchor,
                                  nms_pre_tensor, num_anchor)
            # Rank by centerness-weighted best class score, keep top-k.
            max_scores, _ = (scores * centerness[..., None]).max(-1)
            _, topk_inds = max_scores.topk(nms_pre)
            anchors = anchors[topk_inds, :]
            batch_inds = torch.arange(batch_size).view(
                -1, 1).expand_as(topk_inds).long()
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]
            centerness = centerness[batch_inds, topk_inds]
        else:
            # Broadcast per-level anchors to (batch, num_anchors, 4) so
            # both branches concatenate consistently below.
            anchors = anchors.expand_as(bbox_pred)
        bboxes = self.bbox_coder.decode(
            anchors, bbox_pred, max_shape=img_shapes)
        mlvl_bboxes.append(bboxes)
        mlvl_scores.append(scores)
        mlvl_centerness.append(centerness)
        # BUGFIX: the original also did `mlvl_topk_inds.append(topk_inds)`
        # here, but `topk_inds` is only bound in the topk branch above, so
        # a level taking the else branch raised NameError (or reused a
        # stale value from a previous level). The list was never consumed,
        # so it is removed entirely.
        base_anchors.append(anchors)
    batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
    if rescale:
        batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
            scale_factors).unsqueeze(1)
    batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
    batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)
    batch_base_anchors = torch.cat(base_anchors, dim=1)
    # Set max number of box to be feed into nms in deployment
    deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
    if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
        # BUGFIX: the original overwrote `batch_mlvl_scores` with the 2-D
        # (batch, num_anchors) result of .max(-1) and then indexed it with
        # three indices, which raises IndexError and also discards the
        # per-class scores that multiclass_nms needs. Keep the max in a
        # temporary and gather from the original 3-D score tensor.
        max_scores, _ = (
            batch_mlvl_scores *
            batch_mlvl_centerness.unsqueeze(2).expand_as(batch_mlvl_scores)
        ).max(-1)
        _, topk_inds = max_scores.topk(deploy_nms_pre)
        batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds)
        batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :]
        batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :]
        batch_base_anchors = batch_base_anchors[batch_inds, topk_inds, :]
        batch_mlvl_centerness = batch_mlvl_centerness[batch_inds, topk_inds]
    # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
    # BG cat_id: num_class
    padding = batch_mlvl_scores.new_zeros(batch_size,
                                          batch_mlvl_scores.shape[1], 1)
    batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)
    if with_nms:
        det_results = []
        for (mlvl_bboxes, mlvl_scores, mlvl_centerness,
             img_base_anchors) in zip(batch_mlvl_bboxes, batch_mlvl_scores,
                                      batch_mlvl_centerness,
                                      batch_base_anchors):
            det_bbox, det_label, det_keep = multiclass_nms(
                mlvl_bboxes,
                mlvl_scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=mlvl_centerness,
                return_inds=True)
            # NOTE(review): multiclass_nms returns indices into the
            # class-expanded box list; indexing per-box anchors with them
            # assumes a 1:1 layout -- verify for num_classes > 1.
            det_anchors = img_base_anchors[det_keep, :]
            # Centers and half-sizes of the kept detections.
            det_cx = (det_bbox[:, 0] + det_bbox[:, 2]) * 0.5
            det_cy = (det_bbox[:, 1] + det_bbox[:, 3]) * 0.5
            det_halfw = (det_bbox[:, 2] - det_bbox[:, 0]) * 0.5
            det_halfh = (det_bbox[:, 3] - det_bbox[:, 1]) * 0.5
            # Centers of the anchors the detections were decoded from.
            base_cx = (det_anchors[:, 0] + det_anchors[:, 2]) * 0.5
            base_cy = (det_anchors[:, 1] + det_anchors[:, 3]) * 0.5
            # Anchor-center offset normalized by the box half size; a
            # degenerate (zero-width/height) box yields inf/nan here and
            # is logged as-is.
            det_devitex = (base_cx - det_cx) / det_halfw
            det_devitey = (base_cy - det_cy) / det_halfh
            for i in range(det_devitex.shape[0]):
                log_parts.append(
                    f'{i}({det_devitex[i]:2.2f}, {det_devitey[i]:2.2f}) ')
            log_parts.append('\n')
            det_results.append(tuple([det_bbox, det_label]))
    else:
        det_results = [
            tuple(mlvl_bs)
            for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores,
                               batch_mlvl_centerness)
        ]
    logger.info(''.join(log_parts))
    return det_results