[mmdetection] [Object Detection] Replacing Mask R-CNN's default assigner with ATSSAssigner

This post is only a record of my own experiments; please bear with any mistakes.

The mmdetection version used is 2.14.0.


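For details of the Mask R-CNN and Faster R-CNN code in mmdetection, see the official article 《轻松掌握 MMDetection 中常用算法(二):Faster R-CNN|Mask R-CNN》 (part 2 of the official series on common algorithms in MMDetection).
For the ATSS model, see the official article 《轻松掌握 MMDetection 中常用算法(四):ATSS》 (part 4 of the same series).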

Mask R-CNN code structure

  • detector: MaskRCNN->TwoStageDetector->BaseDetector
  • backbone: ResNet
  • neck: FPN
  • rpn_head: RPNHead->AnchorHead->BaseDenseHead;BBoxTestMixin
  • roi_head: StandardRoIHead->BaseRoIHead;BBoxTestMixin;MaskTestMixin

Implementation

Replacing the assigner in the RPN

  • The corresponding part of the config file needs to be modified.
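
As a rough sketch of what that config change might look like: the fragment below assumes the usual mmdetection 2.x layout, where the RPN assigner sits under `model.train_cfg.rpn`, and borrows `topk=9` from the official ATSS config. The commented-out default values are quoted from memory of the stock Mask R-CNN config and should be checked against your own file.

```python
# Hypothetical config fragment: only the RPN assigner is swapped.
model = dict(
    train_cfg=dict(
        rpn=dict(
            # Default in the stock config (assumed, verify locally):
            # assigner=dict(type='MaxIoUAssigner', pos_iou_thr=0.7,
            #               neg_iou_thr=0.3, min_pos_iou=0.3,
            #               match_low_quality=True, ignore_iof_thr=-1),
            assigner=dict(type='ATSSAssigner', topk=9))))
```

If the config inherits from a `_base_` file, the old MaxIoUAssigner keys would normally be merged into this dict; adding `_delete_=True` inside the assigner dict is the usual way to avoid that.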

In mmdet/models/dense_heads/anchor_head.py:

  • Add the get_num_level_anchors_inside function, following its usage in the ATSS model code.
  • Reimplement the _get_targets_single and get_targets functions.
    def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
        split_inside_flags = torch.split(inside_flags, num_level_anchors)
        num_level_anchors_inside = [
            int(flags.sum()) for flags in split_inside_flags
        ]
        return num_level_anchors_inside

    def _get_targets_single(self,
                            flat_anchors,
                            valid_flags,
                            num_level_anchors,
                            gt_bboxes,
                            gt_bboxes_ignore,
                            gt_labels,
                            img_meta,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors ,4)
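            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors,).
            num_level_anchors (list[int]): Number of anchors of each
                feature level, before the inside-image filtering.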
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            img_meta (dict): Meta info of the image.
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple:
                labels (Tensor): Labels of all anchors, shape (num_anchors,).
                label_weights (Tensor): Label weights of all anchors.
                bbox_targets (Tensor): BBox targets of all anchors,
                    shape (num_anchors, 4).
                bbox_weights (Tensor): BBox weights of all anchors.
                pos_inds (Tensor): Indices of positive anchors.
                neg_inds (Tensor): Indices of negative anchors.
                sampling_result (SamplingResult): Sampling result of the image.
        """
        # check whether anchors fall outside the image
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg.allowed_border)
        if not inside_flags.any():
            return (None, ) * 7
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags, :]

        num_level_anchors_inside = self.get_num_level_anchors_inside(
            num_level_anchors, inside_flags)
        # The default assigner of Mask R-CNN is MaxIoUAssigner
        # assign_result: {gt_inds, max_overlaps, num_gts(int), num_preds(int), labels(None), info({})}
        #   gt_inds: [num_anchors, ](int)
        #   max_overlaps: [num_anchors, ](float)

        # assign_result = self.assigner.assign(
        #     anchors, gt_bboxes, gt_bboxes_ignore,
        #     None if self.sampling else gt_labels)
        # Replaced with the ATSSAssigner used by ATSS
        assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
                                             gt_bboxes, gt_bboxes_ignore,
                                             gt_labels)
        # The default sampler of Mask R-CNN is RandomSampler
        # sampling_result: {pos_bboxes, pos_inds, pos_gt_bboxes(=gt_bboxes), pos_assigned_gt_inds, neg_bboxes, neg_inds, ...}
        sampling_result = self.sampler.sample(assign_result, anchors,
                                              gt_bboxes)

        num_valid_anchors = anchors.shape[0]
        # torch.zeros_like: an all-zero tensor with the same shape as its argument
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)
        # new_full(size, fill_value): a tensor of the given size filled with fill_value
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if not self.reg_decoded_bbox:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            else:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class since v2.5.0
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags,
                fill=self.num_classes)  # fill bg label
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds, sampling_result)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False):
        """Compute regression and classification targets for anchors in
        multiple images.
        [Used for the RPN part of Mask R-CNN (following the usage in ATSSHead)]
        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
                [batch_size * N(num of feature levels) * [num_anchors, 4]]
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, )
                [batch_size * N(num of feature levels) * num_anchors(boolean)]
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
                [batch_size * [num_gts, 4]]
            img_metas (list[dict]): Meta info of each image.
                [batch_size * {}]
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
                [usually None]
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
                [usually None]
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each \
                    level.
                - bbox_targets_list (list[Tensor]): BBox targets of each level.
                - bbox_weights_list (list[Tensor]): BBox weights of each level.
                - num_total_pos (int): Number of positive samples in all \
                    images.
                - num_total_neg (int): Number of negative samples in all \
                    images.
            additional_returns: This function enables user-defined returns from
                `self._get_targets_single`. These returns are currently refined
                to properties at each feature map (i.e. having HxW dimension).
                The results will be concatenated after the end
        """
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        num_level_anchors_list = [num_level_anchors] * num_imgs
        # concat all level anchors to a single tensor
        concat_anchor_list = []
        concat_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            concat_anchor_list.append(torch.cat(anchor_list[i]))
            concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        # multi_apply hides the actual call, so the argument formats are noted here
        # concat_anchor_list: batch_size * [num_anchors(sum), 4]
        # concat_valid_flag_list: batch_size * num_anchors(sum; boolean)
        # the remaining arguments are unchanged
        results = multi_apply(
            self._get_targets_single,
            concat_anchor_list,
            concat_valid_flag_list,
            num_level_anchors_list,
            gt_bboxes_list,
            gt_bboxes_ignore_list,
            gt_labels_list,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)
        # all_labels: batch_size * num_anchors(sum)(int)
        # all_label_weights: batch_size * num_anchors(sum)(float; mostly 0)
        # all_bbox_targets: batch_size * [num_anchors(sum), 4](float; mostly [0, 0, 0, 0])
        # all_bbox_weights: batch_size * [num_anchors(sum), 4](float; mostly [0, 0, 0, 0])
        # pos_inds_list: batch_size * num_pos(int)
        # neg_inds_list: batch_size * num_neg(int)
        # sampling_results_list: batch_size * {} TODO: not yet checked what this is for; it is not used by default either
        (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        rest_results = list(results[7:])  # user-added return values
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        res = (labels_list, label_weights_list, bbox_targets_list,
               bbox_weights_list, num_total_pos, num_total_neg)
        if return_sampling_results:
            res = res + (sampling_results_list, )
        for i, r in enumerate(rest_results):  # user-added return values
            rest_results[i] = images_to_levels(r, num_level_anchors)

        return res + tuple(rest_results)
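
To make the role of get_num_level_anchors_inside concrete, here is a small torch-free sketch of the same bookkeeping. Plain Python lists stand in for tensors, and the function name and toy numbers are illustrative only: the flat inside flags are cut back into per-level chunks and the surviving anchors are counted in each chunk, which is the per-level count that ATSSAssigner needs once out-of-image anchors have been dropped.

```python
def count_inside_per_level(num_level_anchors, inside_flags):
    """Count, per feature level, how many anchors pass the inside check.

    num_level_anchors: list[int], anchors per level before filtering.
    inside_flags: flat list[bool], concatenated level by level.
    """
    counts, start = [], 0
    for n in num_level_anchors:
        counts.append(sum(inside_flags[start:start + n]))
        start += n
    return counts


# Toy example: 3 levels holding 4, 2 and 1 anchors.
flags = [True, False, True, True, False, True, True]
per_level = count_inside_per_level([4, 2, 1], flags)
print(per_level)  # [3, 1, 1]
```

The counts always sum to the number of anchors that survive the filter, so they line up with the filtered `anchors` tensor passed to the assigner.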

Replacing the assigner in the RCNN

  • Replace the corresponding part of the config file.
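
A possible shape for the RCNN side of the config, again only as a sketch: it assumes the assigner lives under `model.train_cfg.rcnn` as in the stock mmdetection 2.x configs, and `topk=9` is taken from the official ATSS config rather than tuned for second-stage proposals. The commented-out defaults are quoted from memory and should be verified.

```python
# Hypothetical config fragment: only the RCNN assigner is swapped.
model = dict(
    train_cfg=dict(
        rcnn=dict(
            # Default in the stock config (assumed, verify locally):
            # assigner=dict(type='MaxIoUAssigner', pos_iou_thr=0.5,
            #               neg_iou_thr=0.5, min_pos_iou=0.5,
            #               match_low_quality=True, ignore_iof_thr=-1),
            assigner=dict(type='ATSSAssigner', topk=9))))
```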

In mmdet/models/roi_heads/standard_roi_head.py:

  • Add the map_roi_levels function
  • Reimplement the forward_train function
    def map_roi_levels(self, rois, num_levels):
        """Map rois to corresponding feature levels by scales.

        - scale < finest_scale * 2: level 0
        - finest_scale * 2 <= scale < finest_scale * 4: level 1
        - finest_scale * 4 <= scale < finest_scale * 8: level 2
        - scale >= finest_scale * 8: level 3

        Args:
            rois (Tensor): Input RoIs, shape (k, 5).
            num_levels (int): Total level number.

        Returns:
            Tensor: Level index (0-based) of each RoI, shape (k, )
        """
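        finest_scale = 56

        # SingleRoIExtractor's version reads columns 1..4 because its rois carry a
        # leading batch index; the per-image proposals here start with x1, so the
        # columns shift by one:
        # scale = torch.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
        scale = torch.sqrt((rois[:, 2] - rois[:, 0]) * (rois[:, 3] - rois[:, 1]))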
        target_lvls = torch.floor(torch.log2(scale / finest_scale + 1e-6))
        target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
        return target_lvls

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """
        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposals (list[Tensors]): list of region proposals.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # assign gts and sample proposals
        if self.with_bbox or self.with_mask:
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            for i in range(num_imgs):
                # Original: MaxIoUAssigner
                # Return: assign_result: {gt_inds, max_overlaps, num_gts(int), num_preds(int), labels(None), info({})}
                #           gt_inds: [num_anchors, ](int)
                #           max_overlaps: [num_anchors, ](float)

                # assign_result = self.bbox_assigner.assign(
                #     proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
                #     gt_labels[i])

                # Replaced with ATSSAssigner (the ATSS usage also filters by inside_flags to get
                # num_level_anchors_inside; inside_flags is not considered here for now)
                # Note: this relies on self.mask_roi_extractor, so a mask head is assumed.
                target_lvls = self.map_roi_levels(proposal_list[i], self.mask_roi_extractor.num_inputs)
                # ATSSAssigner expects boxes ordered level by level
                _, lvls_ind = torch.sort(target_lvls)
                proposal_list[i] = proposal_list[i][lvls_ind, :]
                num_level_anchors = []
                for j in range(self.mask_roi_extractor.num_inputs):
                    # plain int, so the assigner can use it for slicing and topk
                    num_level_anchors.append(int(torch.sum(target_lvls == j)))

                assign_result = self.bbox_assigner.assign(
                    proposal_list[i], num_level_anchors, gt_bboxes[i], gt_bboxes_ignore[i],
                    gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposal_list[i],
                    gt_bboxes[i],
                    gt_labels[i],
                    feats=[lvl_feat[i][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    img_metas)
            losses.update(bbox_results['loss_bbox'])

        # mask head forward and loss
        if self.with_mask:
            mask_results = self._mask_forward_train(x, sampling_results,
                                                    bbox_results['bbox_feats'],
                                                    gt_masks, img_metas)
            losses.update(mask_results['loss_mask'])

        return losses
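
The level-mapping and regrouping step in forward_train can be traced by hand with a torch-free sketch. Lists stand in for tensors, the helper names are illustrative, and `finest_scale=56` is the value from the code above: each proposal is mapped to a level by its scale, the proposals are reordered level by level, and the per-level counts are collected.

```python
import math


def map_roi_level(box, num_levels, finest_scale=56):
    """Level index for one (x1, y1, x2, y2, ...) box, mirroring map_roi_levels."""
    x1, y1, x2, y2 = box[:4]
    scale = math.sqrt((x2 - x1) * (y2 - y1))
    lvl = math.floor(math.log2(scale / finest_scale + 1e-6))
    return min(max(lvl, 0), num_levels - 1)


def group_by_level(boxes, num_levels):
    """Reorder boxes level by level and return the per-level counts."""
    lvls = [map_roi_level(b, num_levels) for b in boxes]
    order = sorted(range(len(boxes)), key=lambda i: lvls[i])  # stable sort
    sorted_boxes = [boxes[i] for i in order]
    counts = [lvls.count(j) for j in range(num_levels)]
    return sorted_boxes, counts


boxes = [
    (0, 0, 500, 500),  # scale 500 -> level 3
    (0, 0, 50, 50),    # scale 50  -> level 0
    (0, 0, 120, 120),  # scale 120 -> level 1
    (0, 0, 230, 230),  # scale 230 -> level 2
    (0, 0, 60, 60),    # scale 60  -> level 0
]
sorted_boxes, counts = group_by_level(boxes, num_levels=4)
print(counts)  # [2, 1, 1, 1]
```

After this step the first `counts[0]` boxes belong to level 0, the next `counts[1]` to level 1, and so on, which is the layout the assigner's per-level slicing relies on.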

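Code logic

For the RPN part, the ATSS code can be followed directly, since ATSS is itself a single-stage model that improves on FCOS.

The RCNN part needs more care. ATSS picks its candidates level by level before computing the IoU threshold, so it needs to know which feature-map level each RoI belongs to. In the RPN this is directly available. In the RCNN I reuse the level-mapping step that runs before RoIAlign and is wrapped in SingleRoIExtractor: its map_roi_levels function is brought in to compute the level of each RoI.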
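
Why the level information matters can be seen in a stripped-down sketch of ATSS-style selection for a single ground-truth box. This is a simplified pure-Python illustration, not mmdetection's ATSSAssigner: the function names are made up, only one gt is handled, and the center check is a plain strict-inside test.

```python
import math
import statistics


def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = ((a[2] - a[0]) * (a[3] - a[1])
             + (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def center(box):
    return ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)


def atss_positives(boxes, num_level_boxes, gt, topk=9):
    """Indices of boxes selected as positives for one gt box."""
    gt_center = center(gt)
    candidates, start = [], 0
    for n in num_level_boxes:
        # Per level: keep the topk boxes whose centers are closest to the gt.
        idxs = sorted(range(start, start + n),
                      key=lambda i: math.dist(center(boxes[i]), gt_center))
        candidates.extend(idxs[:topk])
        start += n
    if not candidates:
        return []
    ious = [iou(boxes[i], gt) for i in candidates]
    # Adaptive threshold: mean + standard deviation of the candidate IoUs.
    std = statistics.stdev(ious) if len(ious) > 1 else 0.0
    thr = statistics.mean(ious) + std
    positives = []
    for i, v in zip(candidates, ious):
        cx, cy = center(boxes[i])
        inside = gt[0] < cx < gt[2] and gt[1] < cy < gt[3]
        if v >= thr and inside:
            positives.append(i)
    return positives


# Level 0 holds the first three boxes, level 1 the last two.
boxes = [(0, 0, 10, 10), (5, 5, 15, 15), (40, 40, 50, 50),
         (0, 0, 20, 20), (30, 30, 50, 50)]
positives = atss_positives(boxes, [3, 2], gt=(4, 4, 16, 16), topk=2)
```

Because the candidates are drawn per level, the boxes must arrive grouped by level together with the per-level counts, which is what the sort and `num_level_anchors` in forward_train provide.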


Put together in a hurry; apologies for any mistakes.
