(9) Deep Learning in Practice | FCOS in MMDetection (2)


1. Introduction

This article walks through the implementation details of the key classes that implement FCOS in MMDetection; the previous article already listed the classes involved.

Figure 1: FCOS architecture


2. The FCOS Class

FCOS is a one-stage detector. In MMDetection, one-stage detectors inherit from the class SingleStageDetector, whose base class is BaseDetector. The main content of BaseDetector is as follows:

class BaseDetector(nn.Module, metaclass=ABCMeta):
    """Base class for detectors."""
    def __init__(self):
        super(BaseDetector, self).__init__()
        self.fp16_enabled = False  # FP16 flag, False by default
    @property
    def with_neck(self):
        # Whether the model has a neck; there are similar properties such as
        # with_shared_head, with_bbox and with_mask.
        # The property decorator lets this be accessed as instance.with_neck.
        return hasattr(self, 'neck') and self.neck is not None
    @abstractmethod  # abstract method: subclasses must implement it, and this class cannot be instantiated
    def extract_feat(self, imgs):
        pass

    def extract_feats(self, imgs):  # unlike the method above, this one extracts features from several images
        assert isinstance(imgs, list)  # the images are passed in as a list
        return [self.extract_feat(img) for img in imgs]  # call extract_feat on each image and collect the results

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        # imgs: (N, C, H, W)
        # img_metas: image meta information such as img_shape, scale_factor, etc.
        # **kwargs: other arguments
        pass

    @abstractmethod  # test without augmentation
    def simple_test(self, img, img_metas, **kwargs):
        pass

    @abstractmethod  # test with test-time augmentation
    def aug_test(self, imgs, img_metas, **kwargs):
        pass

    def init_weights(self, pretrained=None):  # weight initialization
        if pretrained is not None:
            logger = get_root_logger()  # log which pretrained model is loaded
            print_log(f'load model from: {pretrained}', logger=logger)

    def forward_test(self, imgs, img_metas, **kwargs):
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:  # type checks
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')
        num_augs = len(imgs)  # number of augmented versions of the input
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) '
                             f'!= num of image meta ({len(img_metas)})')
        # With a single augmentation, call simple_test; otherwise only batch size 1 is supported.
        if num_augs == 1:
            if 'proposals' in kwargs:
                kwargs['proposals'] = kwargs['proposals'][0]
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            assert imgs[0].size(0) == 1, 'aug test does not support inference with batch size ' \
                                         f'{imgs[0].size(0)}'
            assert 'proposals' not in kwargs
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=('img', ))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def _parse_losses(self, losses):
        # Parse the losses returned by the network and store them in an ordered dict.
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(f'{loss_name} is not a tensor or list of tensors')
        # Total loss: the sum of all entries whose key contains 'loss'.
        loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
        log_vars['loss'] = loss
        for loss_name, loss_value in log_vars.items():
            # In distributed training, average the logged values across processes with all_reduce.
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()
        # Return the total loss and the logging dict.
        return loss, log_vars

    def train_step(self, data, optimizer):
        # Run a forward pass and collect the losses.
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)
        # Pack the outputs expected by the runner.
        outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
        # Return them.
        return outputs

    def val_step(self, data, optimizer):
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))

        return outputs

    def show_result(self, img, result, score_thr=0.3, bbox_color='green', text_color='green',
                    thickness=1, font_scale=0.5, win_name='', show=False, wait_time=0, out_file=None):
        # Read the image and make a copy.
        img = mmcv.imread(img)
        img = img.copy()
        # Check whether the result also contains segmentation results.
        if isinstance(result, tuple):
            bbox_result, segm_result = result
            if isinstance(segm_result, tuple):
                segm_result = segm_result[0]
        else:
            bbox_result, segm_result = result, None
        # Stack the bounding-box coordinates.
        bboxes = np.vstack(bbox_result)
        # Build the label array of the bounding boxes.
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ]
        labels = np.concatenate(labels)
        # If out_file is given, do not show the detections in a window.
        if out_file is not None:
            show = False
        # Draw the detection results.
        mmcv.imshow_det_bboxes(
            img,      # image
            bboxes,   # bounding-box coordinates
            labels,   # bounding-box labels
            class_names=self.CLASSES,   # class names
            score_thr=score_thr,        # score threshold
            bbox_color=bbox_color,      # box color
            text_color=text_color,      # text color
            thickness=thickness,        # line thickness of the boxes
            font_scale=font_scale,      # font scale of the text
            win_name=win_name,          # window name
            show=show,                  # whether to display the result
            wait_time=wait_time,        # waiting time of the window
            out_file=out_file)          # output file, if the result should be saved
        if not (show or out_file):
            return img
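To see what _parse_losses does with the loss dictionary a head returns, here is a minimal, self-contained sketch using plain PyTorch tensors; the loss names and values below are made up for illustration only:

import torch
from collections import OrderedDict

# A loss dict as a head might return it (hypothetical values).
losses = dict(
    loss_cls=torch.tensor(0.8),
    loss_bbox=[torch.tensor(0.3), torch.tensor(0.5)],  # e.g. one entry per level
)

log_vars = OrderedDict()
for name, value in losses.items():
    # Tensors are averaged; lists of tensors are averaged element-wise and then summed.
    log_vars[name] = value.mean() if isinstance(value, torch.Tensor) \
        else sum(v.mean() for v in value)

# Every entry whose key contains 'loss' contributes to the total loss.
total_loss = sum(v for k, v in log_vars.items() if 'loss' in k)
print(total_loss)  # tensor(1.6000) = 0.8 + (0.3 + 0.5)

This mirrors the logic above: train_step then packs the total loss and the reduced log_vars into the output dict consumed by the runner.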

The main parts of SingleStageDetector, the subclass of BaseDetector, are as follows:

@DETECTORS.register_module()  # register the class with the DETECTORS registry
class SingleStageDetector(BaseDetector):  # a one-stage detector predicts boxes and classes directly from the backbone features
    def __init__(self, backbone, neck=None, bbox_head=None, train_cfg=None, test_cfg=None,
                 pretrained=None):
        super(SingleStageDetector, self).__init__()
        self.backbone = build_backbone(backbone)  # build the backbone via the build function
        if neck is not None:  # build the neck if one is configured
            self.neck = build_neck(neck)
        bbox_head.update(train_cfg=train_cfg)  # merge the training config into the head config
        bbox_head.update(test_cfg=test_cfg)    # merge the test config into the head config
        self.bbox_head = build_head(bbox_head)  # build the head via the build function
        self.train_cfg = train_cfg  # training configuration
        self.test_cfg = test_cfg    # test configuration
        self.init_weights(pretrained=pretrained)  # initialize/load pretrained weights

    def init_weights(self, pretrained=None):
        # Initialize the backbone weights.
        super(SingleStageDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        # Initialize the neck weights.
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        # Initialize the head weights.
        self.bbox_head.init_weights()

    def extract_feat(self, img):  # implements the abstract method of BaseDetector
        # Pass the image through backbone + neck.
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    # Implements the abstract method of BaseDetector.
    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore=None):
        # Run backbone + neck on the input image to obtain the features x.
        x = self.extract_feat(img)
        # Call the head's forward_train method to compute the losses.
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_metas, rescale=False):  # implements the abstract method of BaseDetector
        # Run backbone + neck.
        x = self.extract_feat(img)
        # Run the head.
        outs = self.bbox_head(x)
        # Convert the head outputs into a list of bounding boxes.
        bbox_list = self.bbox_head.get_bboxes(*outs, img_metas, rescale=rescale)
        # Skip post-processing when exporting to ONNX.
        if torch.onnx.is_in_onnx_export():
            return bbox_list
        # Convert the box list into per-class results.
        bbox_results = [bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
                        for det_bboxes, det_labels in bbox_list]
        # Return the results.
        return bbox_results

    def aug_test(self, imgs, img_metas, rescale=False):  # implements the abstract method of BaseDetector
        assert hasattr(self.bbox_head, 'aug_test'), \
            f'{self.bbox_head.__class__.__name__}' \
            ' does not support test-time augmentation'
        # Extract features from the list of augmented images.
        feats = self.extract_feats(imgs)
        # Call the head's aug_test method (provided by BBoxTestMixin).
        return [self.bbox_head.aug_test(feats, img_metas, rescale=rescale)]

Finally, the FCOS class itself:

@DETECTORS.register_module()
class FCOS(SingleStageDetector):
    def __init__(self, backbone, neck, bbox_head, train_cfg=None, test_cfg=None, pretrained=None):
        super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg, pretrained)
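In practice none of these classes are instantiated by hand; they are assembled from a config dict through the registries used above. A simplified sketch of such a config, loosely following the configs/fcos directory of mmdetection (the exact keys and values may differ from the real config files), would look like this:

model = dict(
    type='FCOS',                       # looked up in the DETECTORS registry
    backbone=dict(
        type='ResNet',                 # looked up in the BACKBONES registry
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        style='caffe'),
    neck=dict(
        type='FPN',                    # looked up in the NECKS registry
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        num_outs=5),
    bbox_head=dict(
        type='FCOSHead',               # looked up in the HEADS registry
        num_classes=80,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        strides=[8, 16, 32, 64, 128],
        loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0),
        loss_bbox=dict(type='IoULoss', loss_weight=1.0),
        loss_centerness=dict(type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))
# Building the detector then boils down to something like:
#   from mmdet.models import build_detector
#   detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)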

3. The FCOSHead Class

FCOS is an anchor-free detector. FCOSHead inherits from the class AnchorFreeHead, which in turn has two base classes, BaseDenseHead and BBoxTestMixin. Let's look at the two base classes first:

class BaseDenseHead(nn.Module, metaclass=ABCMeta):
    def __init__(self):
        super(BaseDenseHead, self).__init__()

    @abstractmethod  # abstract method: subclasses must implement it
    def loss(self, **kwargs):
        # Compute the losses.
        pass

    @abstractmethod  # abstract method: subclasses must implement it
    def get_bboxes(self, **kwargs):
        # Convert the model outputs into bounding boxes.
        pass

    def forward_train(self,
                      x,                      # features output by the FPN
                      img_metas,              # image meta information
                      gt_bboxes,              # ground-truth boxes
                      gt_labels=None,         # ground-truth labels (differs between anchor-free and anchor-based heads)
                      gt_bboxes_ignore=None,  # ground-truth boxes to ignore
                      proposal_cfg=None,      # proposal config (differs between anchor-free and anchor-based heads)
                      **kwargs):
        # Returns the losses and, optionally, the proposals.
        outs = self(x)
        # Whether ground-truth labels are available.
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        # Compute the losses.
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        # Generate the proposals if a proposal config is given.
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
            return losses, proposal_list

class BBoxTestMixin(object):
    # Map the boxes predicted on the augmented images back to the original image and merge them.
    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        recovered_bboxes = []
        for bboxes, img_info in zip(aug_bboxes, img_metas):
            img_shape = img_info[0]['img_shape']
            scale_factor = img_info[0]['scale_factor']
            flip = img_info[0]['flip']
            flip_direction = img_info[0]['flip_direction']
            bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
                                       flip_direction)
            recovered_bboxes.append(bboxes)
        bboxes = torch.cat(recovered_bboxes, dim=0)
        if aug_scores is None:
            return bboxes
        else:
            scores = torch.cat(aug_scores, dim=0)
            return bboxes, scores

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        # Test-time augmentation requires get_bboxes/_get_bboxes_single to support with_nms.
        gb_sig = signature(self.get_bboxes)
        gb_args = [p.name for p in gb_sig.parameters.values()]
        gbs_sig = signature(self._get_bboxes_single)
        gbs_args = [p.name for p in gbs_sig.parameters.values()]
        assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \
            f'{self.__class__.__name__} does not support test-time augmentation'

        aug_bboxes = []
        aug_scores = []
        aug_factors = []
        for x, img_meta in zip(feats, img_metas):
            # Run the head on each augmented view and collect boxes/scores without NMS.
            outs = self.forward(x)
            bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
            bbox_outputs = self.get_bboxes(*bbox_inputs)[0]
            aug_bboxes.append(bbox_outputs[0])
            aug_scores.append(bbox_outputs[1])
            # Some heads (e.g. FCOS) also return score factors such as center-ness.
            if len(bbox_outputs) >= 3:
                aug_factors.append(bbox_outputs[2])

        # Merge the boxes of all augmented views and apply NMS once.
        merged_bboxes, merged_scores = self.merge_aug_bboxes(aug_bboxes, aug_scores, img_metas)
        merged_factors = torch.cat(aug_factors, dim=0) if aug_factors else None
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes,
            merged_scores,
            self.test_cfg.score_thr,
            self.test_cfg.nms,
            self.test_cfg.max_per_img,
            score_factors=merged_factors)

        if rescale:
            _det_bboxes = det_bboxes
        else:
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= det_bboxes.new_tensor(img_metas[0][0]['scale_factor'])
        bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
        return bbox_results

The main content of AnchorFreeHead is as follows:

@HEADS.register_module()  # register the class with the HEADS registry
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    def __init__(self,
                 num_classes,                # number of classes
                 in_channels,                # number of input channels
                 feat_channels=256,          # number of feature-map channels
                 stacked_convs=4,            # number of stacked conv layers
                 strides=(4, 8, 16, 32, 64), # downsampling strides of the feature maps
                 dcn_on_last_conv=False,     # whether to use DCN in the last conv layer
                 conv_bias='auto',           # bias setting of the conv layers
                 loss_cls=dict(              # config of the classification loss
                     type='FocalLoss',       # the classification branch uses FocalLoss
                     use_sigmoid=True,       # whether to use sigmoid
                     gamma=2.0,              # FocalLoss parameter 1
                     alpha=0.25,             # FocalLoss parameter 2
                     loss_weight=1.0),       # weight of the classification loss
                 loss_bbox=dict(             # config of the regression loss
                     type='IoULoss',         # the regression branch uses IoULoss
                     loss_weight=1.0),       # weight of the regression loss
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None):
        super(AnchorFreeHead, self).__init__()
        self.num_classes = num_classes
        self.cls_out_channels = num_classes  # the classification branch outputs one channel per class
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = build_loss(loss_cls)    # build the classification loss via the build function
        self.loss_bbox = build_loss(loss_bbox)  # build the regression loss via the build function
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False
        self._init_layers()  # build the layers

    def _init_layers(self):
        # Build the layers of the head.
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _init_cls_convs(self):
        # Build the classification branch.
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):  # one ConvModule per stacked conv
            chn = self.in_channels if i == 0 else self.feat_channels  # determine the input channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:  # optionally replace the last conv with DCN
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.cls_convs.append(
                ConvModule(  # conv + norm + activation
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_reg_convs(self):
        # Build the regression branch.
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):  # one ConvModule per stacked conv
            chn = self.in_channels if i == 0 else self.feat_channels  # determine the input channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:  # optionally replace the last conv with DCN
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.reg_convs.append(
                ConvModule(  # conv + norm + activation
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_predictor(self):
        # Build the prediction layers at the end of the classification and regression branches.
        self.conv_cls = nn.Conv2d(self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def init_weights(self):
        # Initialize the weights.
        for m in self.cls_convs:  # initialize the classification branch weights
            if isinstance(m.conv, nn.Conv2d):
                normal_init(m.conv, std=0.01)
        for m in self.reg_convs:  # initialize the regression branch weights
            if isinstance(m.conv, nn.Conv2d):
                normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.conv_cls, std=0.01, bias=bias_cls)
        normal_init(self.conv_reg, std=0.01)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Rename parameters so that checkpoints from older versions can still be loaded.
        version = local_metadata.get('version', None)
        if version is None:
            # The naming differs from older versions, e.g. fcos_cls was renamed to conv_cls.
            bbox_head_keys = [k for k in state_dict.keys() if k.startswith(prefix)]
            # Lists holding the old and the new key names.
            ori_predictor_keys = []
            new_predictor_keys = []
            # e.g. fcos_cls or fcos_reg
            for key in bbox_head_keys:
                ori_predictor_keys.append(key)
                key = key.split('.')
                conv_name = None
                if key[1].endswith('cls'):
                    conv_name = 'conv_cls'
                elif key[1].endswith('reg'):
                    conv_name = 'conv_reg'
                elif key[1].endswith('centerness'):
                    conv_name = 'conv_centerness'
                else:
                    assert NotImplementedError
                if conv_name is not None:
                    key[1] = conv_name
                    new_predictor_keys.append('.'.join(key))
                else:
                    ori_predictor_keys.pop(-1)
            # Update the state dict with the new key names.
            for i in range(len(new_predictor_keys)):
                state_dict[new_predictor_keys[i]] = state_dict.pop(ori_predictor_keys[i])
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, feats):
        # The input is a tuple of multi-level feature maps; return the classification
        # scores and bounding-box predictions of each level.
        return multi_apply(self.forward_single, feats)[:2]

    def forward_single(self, x):
        cls_feat = x
        reg_feat = x
        # Classification branch.
        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)
        # Regression branch.
        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        # Return the predictions together with the intermediate features.
        return cls_score, bbox_pred, cls_feat, reg_feat

    @abstractmethod  # abstract method
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self, cls_scores, bbox_preds, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore=None):
        # Compute the losses of the head from the predictions and the ground truth.
        raise NotImplementedError

    @abstractmethod  # abstract method
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self, cls_scores, bbox_preds, img_metas, cfg=None, rescale=None):
        # Convert the network outputs into bounding-box predictions.
        raise NotImplementedError

    @abstractmethod  # abstract method
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        # Compute the classification, regression and center-ness targets.
        raise NotImplementedError

    def _get_points_single(self, featmap_size, stride, dtype, device, flatten=False):
        # Get the coordinates of all points on a single feature map.
        h, w = featmap_size
        # x_range = Tensor([0, 1, ..., w-1])
        x_range = torch.arange(w, dtype=dtype, device=device)
        # y_range = Tensor([0, 1, ..., h-1])
        y_range = torch.arange(h, dtype=dtype, device=device)
        # meshgrid returns y first (the row index, constant along each row) and x second
        # (the column index, constant along each column); both have shape (h, w).
        y, x = torch.meshgrid(y_range, x_range)
        # Optionally flatten the grids.
        if flatten:
            y = y.flatten()
            x = x.flatten()
        # Return the grids.
        return y, x

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        # Get the point coordinates on all feature maps at once.
        mlvl_points = []
        for i in range(len(featmap_sizes)):
            mlvl_points.append(self._get_points_single(featmap_sizes[i], self.strides[i],
                               dtype, device, flatten))
        # Return them as a list, one entry per level.
        return mlvl_points

    def aug_test(self, feats, img_metas, rescale=False):
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
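Before moving on, a tiny standalone example (plain PyTorch, not MMDetection code) of the grids produced in _get_points_single may help; the shapes and values follow directly from torch.meshgrid:

import torch

h, w = 2, 3
y_range = torch.arange(h, dtype=torch.float32)   # tensor([0., 1.])
x_range = torch.arange(w, dtype=torch.float32)   # tensor([0., 1., 2.])
y, x = torch.meshgrid(y_range, x_range)
# y holds the row index of every cell,      x holds the column index:
# y = [[0., 0., 0.],                        x = [[0., 1., 2.],
#      [1., 1., 1.]]                             [0., 1., 2.]]
# With flatten=True both grids are flattened to 1-D tensors of length h * w = 6.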

Finally, the key content of FCOSHead:

@HEADS.register_module()  # register the class with the HEADS registry
class FCOSHead(AnchorFreeHead):
    def __init__(self,
                 num_classes,   # number of classes
                 in_channels,   # number of input channels
                 # FCOS assigns a regression range to each FPN level; targets outside
                 # the range are not regressed on that level
                 regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF)),
                 center_sampling=False,     # whether to use center sampling
                 center_sample_radius=1.5,  # radius of the center region when center sampling is used
                 norm_on_bbox=False,        # whether to normalize the regression targets by the stride
                 centerness_on_reg=False,   # whether the center-ness branch is attached to the regression branch
                 loss_cls=dict(  # FocalLoss config of the classification branch
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),  # IoULoss config of the regression branch
                 loss_centerness=dict(  # CrossEntropyLoss config of the center-ness branch
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
                 **kwargs):
        self.regress_ranges = regress_ranges
        self.center_sampling = center_sampling
        self.center_sample_radius = center_sample_radius
        self.norm_on_bbox = norm_on_bbox
        self.centerness_on_reg = centerness_on_reg
        super().__init__(num_classes, in_channels, loss_cls=loss_cls, loss_bbox=loss_bbox,
            norm_cfg=norm_cfg, **kwargs)
        self.loss_centerness = build_loss(loss_centerness)  # build the center-ness loss via the build function

    def _init_layers(self):
        # Build the layers of the head.
        super()._init_layers()
        self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
        self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])

    def init_weights(self):
        # Initialize the weights of the head.
        super().init_weights()
        normal_init(self.conv_centerness, std=0.01)

    def forward(self, feats):
        # Similar to the base class AnchorFreeHead, but also passes the per-level scale and stride.
        return multi_apply(self.forward_single, feats, self.scales, self.strides)

    def forward_single(self, x, scale, stride):
        # Call AnchorFreeHead.forward_single for the shared part; FCOS adds a center-ness branch.
        cls_score, bbox_pred, cls_feat, reg_feat = super().forward_single(x)
        # Attach the center-ness branch to either the regression or the classification branch.
        if self.centerness_on_reg:
            centerness = self.conv_centerness(reg_feat)
        else:
            centerness = self.conv_centerness(cls_feat)
        # scale is a learnable per-level factor, initialized to 1.0 for every level.
        bbox_pred = scale(bbox_pred).float()
        if self.norm_on_bbox:
            bbox_pred = F.relu(bbox_pred)
            if not self.training:  # only at test time
                # multiply by the stride to scale the predictions back to image size
                bbox_pred *= stride
        else:
            bbox_pred = bbox_pred.exp()
        # Return the classification scores, box predictions and center-ness values.
        return cls_score, bbox_pred, centerness

The two forward functions above define the forward pass; next let's look at the other important functions of FCOSHead. The first group converts the outputs of the three branches into bounding-box results: get_bboxes processes a whole batch, while _get_bboxes_single handles a single image:

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
    def get_bboxes(self, cls_scores, bbox_preds, centernesses, img_metas, cfg=None,
                   rescale=False, with_nms=True):
        # Parse the network outputs into bounding boxes.
        assert len(cls_scores) == len(bbox_preds)
        # Number of FPN output levels.
        num_levels = len(cls_scores)
        # Sizes of the feature maps.
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        # Point coordinates on all feature maps.
        mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device)
        result_list = []
        for img_id in range(len(img_metas)):
            # Collect the classification scores of this image.
            cls_score_list = [cls_scores[i][img_id].detach() for i in range(num_levels)]
            # Collect the box predictions of this image.
            bbox_pred_list = [bbox_preds[i][img_id].detach() for i in range(num_levels)]
            # Collect the center-ness predictions of this image.
            centerness_pred_list = [centernesses[i][img_id].detach() for i in range(num_levels)]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            # Call _get_bboxes_single to generate the boxes of a single image.
            det_bboxes = self._get_bboxes_single(
                cls_score_list, bbox_pred_list, centerness_pred_list,
                mlvl_points, img_shape, scale_factor, cfg, rescale, with_nms)
            result_list.append(det_bboxes)
        # Return the per-image results.
        return result_list

    def _get_bboxes_single(self, cls_scores, bbox_preds, centernesses, mlvl_points, img_shape,
                           scale_factor, cfg, rescale=False, with_nms=True):
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        # Hold the results of all levels.
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_centerness = []
        # Iterate over the levels.
        for cls_score, bbox_pred, centerness, points in zip(
                cls_scores, bbox_preds, centernesses, mlvl_points):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            # Permute, reshape and apply sigmoid.
            scores = cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels).sigmoid()
            centerness = centerness.permute(1, 2, 0).reshape(-1).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            # If there are more boxes than nms_pre, keep only the nms_pre highest-scoring ones;
            # the ranking score is score * centerness.
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = (scores * centerness[:, None]).max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                centerness = centerness[topk_inds]
            # Decode the predictions into boxes.
            bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_centerness.append(centerness)
        # Concatenate the boxes of all levels.
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        # Rescale to the original image size if requested.
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        # Concatenate the scores.
        mlvl_scores = torch.cat(mlvl_scores)
        # Pad an extra background column.
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        mlvl_centerness = torch.cat(mlvl_centerness)
        # NMS
        if with_nms:
            det_bboxes, det_labels = multiclass_nms(
                mlvl_bboxes,
                mlvl_scores,
                cfg.score_thr,
                cfg.nms,
                cfg.max_per_img,
                score_factors=mlvl_centerness)
            return det_bboxes, det_labels
        else:
            return mlvl_bboxes, mlvl_scores, mlvl_centerness
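The decoding step distance2bbox above turns a point (x, y) and its predicted distances (l, t, r, b) into corner coordinates (x1, y1, x2, y2) = (x - l, y - t, x + r, y + b). A tiny standalone equivalent (ignoring the max_shape clamping of the real implementation) looks like this:

import torch

def decode(points, distances):
    # points: (n, 2) as (x, y); distances: (n, 4) as (left, top, right, bottom)
    x1 = points[:, 0] - distances[:, 0]
    y1 = points[:, 1] - distances[:, 1]
    x2 = points[:, 0] + distances[:, 2]
    y2 = points[:, 1] + distances[:, 3]
    return torch.stack([x1, y1, x2, y2], dim=-1)

pts = torch.tensor([[100., 60.]])
ltrb = torch.tensor([[30., 20., 50., 40.]])
print(decode(pts, ltrb))  # tensor([[ 70.,  40., 150., 100.]])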

Next, _get_points_single obtains the point coordinates on a single feature map of the pyramid; combined with the parent class AnchorFreeHead's get_points, it yields the point coordinates of all pyramid levels.

    def _get_points_single(self, featmap_size, stride, dtype, device, flatten=False):
        # Call AnchorFreeHead._get_points_single to get the point grid of a single feature map.
        y, x = super()._get_points_single(featmap_size, stride, dtype, device)
        # Formula from the paper: (floor(s/2) + x*s, floor(s/2) + y*s); points are the locations
        # on the original image that the feature-map cells map back to.
        points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride), dim=-1) + stride // 2
        return points
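As a quick numeric check of this mapping, assume stride s = 8: the feature-map cell at (x, y) = (3, 2) maps back to the image location (3 * 8 + 4, 2 * 8 + 4) = (28, 20), roughly the center of the 8x8 image patch covered by that cell. In code:

import torch

stride = 8
x = torch.tensor([0., 1., 3.])
y = torch.tensor([0., 0., 2.])
points = torch.stack((x * stride, y * stride), dim=-1) + stride // 2
print(points)  # tensor([[ 4.,  4.], [12.,  4.], [28., 20.]])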

Next, the targets of each branch are computed: the classification branch, the regression branch and the center-ness branch:

    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        # Both have one entry per feature-pyramid level.
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)
        # Expand the regression ranges so that every point of a level carries that level's range.
        # Taking the first level as an example:
        # [[-1, 64], [-1, 64], ..., [-1, 64]] with shape (num_points_of_level_0, 2)
        expanded_regress_ranges = [
            points[i].new_tensor(self.regress_ranges[i])[None].expand_as(points[i])
            for i in range(num_levels)
        ]
        # Concatenate over the levels: concat_regress_ranges = (total_num_points, 2)
        concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
        # concat_points = (total_num_points, 2)
        concat_points = torch.cat(points, dim=0)
        # Number of points of each level.
        num_points = [center.size(0) for center in points]
        # Compute the targets for every image in the batch.
        labels_list, bbox_targets_list = multi_apply(
            self._get_target_single,
            gt_bboxes_list,
            gt_labels_list,
            points=concat_points,
            regress_ranges=concat_regress_ranges,
            num_points_per_lvl=num_points)
        # Split each image's labels back into per-level chunks.
        labels_list = [labels.split(num_points, 0) for labels in labels_list]
        # Split each image's box regression targets back into per-level chunks.
        bbox_targets_list = [bbox_targets.split(num_points, 0) for bbox_targets in bbox_targets_list]
        # Concatenate over the images, level by level.
        concat_lvl_labels = []
        concat_lvl_bbox_targets = []
        # Iterate over the levels and concatenate.
        for i in range(num_levels):
            concat_lvl_labels.append(
                torch.cat([labels[i] for labels in labels_list]))
            bbox_targets = torch.cat(
                [bbox_targets[i] for bbox_targets in bbox_targets_list])
            # Optionally normalize the targets by the stride.
            if self.norm_on_bbox:
                bbox_targets = bbox_targets / self.strides[i]
            concat_lvl_bbox_targets.append(bbox_targets)
        # Return the per-level labels and box targets.
        return concat_lvl_labels, concat_lvl_bbox_targets
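    # To make the split/concat above concrete, consider a hypothetical batch of two images
    # and two levels with num_points = [3, 2] (5 points per image in total):
    #   per-image labels:    img0 = [a0, a1, a2, a3, a4], img1 = [b0, b1, b2, b3, b4]
    #   after .split():      img0 -> ([a0, a1, a2], [a3, a4]), img1 -> ([b0, b1, b2], [b3, b4])
    #   after per-level cat: level 0 = [a0, a1, a2, b0, b1, b2], level 1 = [a3, a4, b3, b4]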

    def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges, num_points_per_lvl):
        # Compute the classification and regression targets for a single image.
        num_points = points.size(0)
        num_gts = gt_labels.size(0)
        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.num_classes), \
                   gt_bboxes.new_zeros((num_points, 4))
        # Areas of the ground-truth boxes, gt_bboxes = (num_gts, 4)
        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
        areas = areas[None].repeat(num_points, 1)
        # Expand the regression ranges so that they align with the ground-truth boxes.
        regress_ranges = regress_ranges[:, None, :].expand(num_points, num_gts, 2)
        # [None] adds a dimension: (1, num_gts, 4) => (num_points, num_gts, 4)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        # points = (num_points, 2); the two columns are the x and y coordinates
        xs, ys = points[:, 0], points[:, 1]
        # xs = (num_points,) => (num_points, num_gts)
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)
        # Distances from each point to the four sides of each ground-truth box, i.e. the
        # regression targets from the paper; left/right/top/bottom = (num_points, num_gts)
        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        # Stack them: bbox_targets = (num_points, num_gts, 4)
        bbox_targets = torch.stack((left, top, right, bottom), -1)
        # Center sampling: only points falling into a central region of the ground-truth box
        # are treated as positive samples.
        if self.center_sampling:
            # Radius that defines the central region.
            radius = self.center_sample_radius
            # Center coordinates of the ground-truth boxes, center_xs = (num_points, num_gts)
            center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
            center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
            center_gts = torch.zeros_like(gt_bboxes)
            # stride.shape = (num_points, num_gts)
            stride = center_xs.new_zeros(center_xs.shape)
            # Iterate over all points of all pyramid levels.
            lvl_begin = 0
            for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
                # Index of the first point of the next level.
                lvl_end = lvl_begin + num_points_lvl
                # Since the FPN levels have different resolutions, the stride value differs per level:
                # self.strides = [8, 16, 32, 64, 128], so self.strides[lvl_idx] * radius = [12, 24, 48, 96, 192].
                # Points of the first level get 12, points of the second level get 24, ..., the fifth level 192.
                stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
                # Move on to the next level.
                lvl_begin = lvl_end
            # center_xs holds the box centers, stride the per-level values [12, 24, 48, 96, 192];
            # both have shape (num_points, num_gts). The following lines operate level by level;
            # take the first level (stride 12) as an example:
            x_mins = center_xs - stride  # box center shifted 12 pixels to the left
            y_mins = center_ys - stride  # box center shifted 12 pixels upward
            x_maxs = center_xs + stride  # box center shifted 12 pixels to the right
            y_maxs = center_ys + stride  # box center shifted 12 pixels downward
            # Define the valid central region for regression. torch.where(cond, a, b) returns a
            # where the condition holds and b otherwise.
            # gt_bboxes is (x_min, y_min, x_max, y_max). Taking the first line below as an example:
            # if the box center shifted 12 pixels to the left still lies inside the box (greater than
            # its left edge), keep that position; otherwise clip it to the left edge. The other lines
            # are analogous.
            center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0], x_mins, gt_bboxes[..., 0])
            center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1], y_mins, gt_bboxes[..., 1])
            center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2], gt_bboxes[..., 2], x_maxs)
            center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3], gt_bboxes[..., 3], y_maxs)
            # Distances from each point to the borders of the central region.
            cb_dist_left = xs - center_gts[..., 0]
            cb_dist_right = center_gts[..., 2] - xs
            cb_dist_top = ys - center_gts[..., 1]
            cb_dist_bottom = center_gts[..., 3] - ys
            # center_bbox = (num_points, num_gts, 4)
            center_bbox = torch.stack((cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
            # With center sampling, a point is positive only if it falls inside the central region,
            # i.e. if even its smallest distance to the region borders is positive.
            inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        else:
            # Without center sampling, a point is positive as soon as it falls inside the box.
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
        # Restrict the regression range of each location.
        max_regress_distance = bbox_targets.max(-1)[0]
        # regress_ranges = ((-1, 64), (64, 128), (128, 256), (256, 512), (512, INF))
        inside_regress_range = ((max_regress_distance >= regress_ranges[..., 0])
                                & (max_regress_distance <= regress_ranges[..., 1]))
        # If a location falls inside several ground-truth boxes, it regresses the one with the
        # smallest area; invalid positions are first set to INF so they are filtered out below.
        areas[inside_gt_bbox_mask == 0] = INF
        areas[inside_regress_range == 0] = INF
        # Minimum area and the corresponding ground-truth index.
        min_area, min_area_inds = areas.min(dim=1)
        # Label of the assigned ground-truth box.
        labels = gt_labels[min_area_inds]
        # Locations without a valid ground truth are set to the background class.
        labels[min_area == INF] = self.num_classes
        # Regression target of the assigned ground-truth box.
        bbox_targets = bbox_targets[range(num_points), min_area_inds]
        # Return the classification and box regression targets.
        return labels, bbox_targets
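    # A quick numeric check of the (l, t, r, b) targets above, with purely illustrative values:
    # for a location (x, y) = (120, 150) and a ground-truth box (x1, y1, x2, y2) = (80, 100, 260, 300):
    #   l = 120 - 80 = 40,  t = 150 - 100 = 50,  r = 260 - 120 = 140,  b = 300 - 150 = 150
    # All four distances are positive, so without center sampling this location counts as a
    # positive sample; its max regression distance is 150, which falls into the range (128, 256),
    # so the location is assigned to the corresponding FPN level.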

    def centerness_target(self, pos_bbox_targets):
        # The center-ness target is computed only at positive locations.
        left_right = pos_bbox_targets[:, [0, 2]]
        top_bottom = pos_bbox_targets[:, [1, 3]]
        # Center-ness formula from the paper.
        centerness_targets = (
            left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
                top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        # Take the square root and return.
        return torch.sqrt(centerness_targets)
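This implements the center-ness target defined in the FCOS paper: for a positive location with regression targets $(l^{*}, t^{*}, r^{*}, b^{*})$,

$$\text{centerness}^{*}=\sqrt{\frac{\min(l^{*},r^{*})}{\max(l^{*},r^{*})}\times\frac{\min(t^{*},b^{*})}{\max(t^{*},b^{*})}}$$

which equals 1 at the center of a box and decays towards 0 near its borders.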

The last important function is loss, which reshapes all the quantities involved in the loss computation into the format expected by the loss functions:

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
    def loss(self, cls_scores, bbox_preds, centernesses, gt_bboxes, gt_labels, img_metas,
             gt_bboxes_ignore=None):
        # Compute the losses of all branches.
        # cls_scores: list over levels, each (N, num_classes, H, W)
        # bbox_preds: list over levels, each (N, 4, H, W)
        # centernesses: list over levels, each (N, 1, H, W)
        # gt_bboxes: per-image tensors of shape (num_gts, 4)
        # gt_labels: per-image tensors of shape (num_gts,)
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        # Sizes (h, w) of the FPN feature maps.
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        # All points on the feature maps, one (num_points_lvl, 2) tensor per level.
        all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype, bbox_preds[0].device)
        # Targets: per-level labels and per-level (num_points, 4) box targets.
        labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes, gt_labels)
        # Batch size.
        num_imgs = cls_scores[0].size(0)
        # Flatten everything; permute rearranges the tensors first:
        # cls_score: (N, C, H, W) => (N, H, W, C) => (N*H*W, C); the other variables are analogous.
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        # Concatenate over the levels.
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # Repeat the point coordinates so that they line up with the flattened predictions.
        flatten_points = torch.cat([points.repeat(num_imgs, 1) for points in all_level_points])
        # Foreground class ids: [0, num_classes - 1]; background class id: num_classes.
        bg_class_ind = self.num_classes
        # Indices of the positive samples, i.e. locations with a foreground label.
        pos_inds = ((flatten_labels >= 0) & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        # Number of positive samples.
        num_pos = len(pos_inds)
        # Classification loss.
        loss_cls = self.loss_cls(flatten_cls_scores, flatten_labels, avg_factor=num_pos + num_imgs)
        # Select the box predictions and center-ness values of the positive samples.
        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        # If there is at least one positive sample:
        if num_pos > 0:
            # Box regression targets of the positive samples.
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            # Center-ness targets of the positive samples.
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            # Positive locations; in FCOS, locations play the role of samples.
            pos_points = flatten_points[pos_inds]
            # Decode the (l, t, r, b) predictions and targets into actual boxes to compute the IoULoss.
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points, pos_bbox_targets)
            # Regression loss (IoULoss), weighted by the center-ness targets.
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=pos_centerness_targets.sum())
            # Center-ness loss (CrossEntropyLoss with sigmoid).
            loss_centerness = self.loss_centerness(pos_centerness, pos_centerness_targets)
        else:
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()
        # Return the losses of the three branches as a dictionary.
        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_centerness=loss_centerness)

4. Summary

The two most important member functions introduced in this article are loss and get_targets. loss mainly reshapes the model outputs and the ground-truth annotations into the format expected by the loss functions, while get_targets finds the appropriate regression target for every positive sample.


5. References

  1. https://github.com/open-mmlab/mmdetection.

