SOLO(ECCV 2020)论文与代码解读

paper:SOLO: Segmenting Objects by Locations

official implementation:GitHub - aim-uofa/AdelaiDet: AdelaiDet is an open source toolbox for multiple instance-level detection and recognition tasks.

third-party implementation:mmdetection/mmdet/models/dense_heads/solo_head.py at main · open-mmlab/mmdetection · GitHub

背景

实例分割(Instance segmentation)是计算机视觉中的一个重要任务,其目的是在图像中对每个独立的物体实例进行分割和标注。与语义分割不同,实例分割不仅需要对物体类别进行分类,还需要在同类物体中区分不同的实例。这使得实例分割任务比其他密集预测任务(如语义分割)更具挑战性。

存在的问题

传统的实例分割方法可以分为两类:自上而下的方法(top-down)和自下而上的方法(bottom-up)。自上而下的方法先检测物体的边界框,然后在每个边界框内进行实例分割;而自下而上的方法则学习像素间的关联,通过嵌入向量将像素分组为不同的实例。这些方法通常步骤繁多,依赖于精确的边界框检测或像素嵌入学习和后处理。

本文的创新点

本文提出了一种新的实例分割算法SOLO(Segment Objects by Locations), 提出了一种全新的实例分割视角,将实例分割任务转换为一个分类可以解决的问题。具体而言,SOLO通过引入“实例类别”的概念,根据实例的中心位置和大小将类别分配给每个像素,从而将实例分割任务简化为一个分类问题。这样可以直接在像素级别生成实例掩码,而无需边界框或像素嵌入学习和分组处理。

方法介绍

以COCO验证集为例,总共包含36780个物体,其中98.3%的物体pairs之间的中心点距离超过30个像素,剩下的1.7%的物体对中,40.5%大小超过1.5倍。总之,在大多数情况下,图像中的两个物体要么中心位置不同,要么大小不同。因此作者想到能否通过中心位置和物体大小来直接区分实例。

Problem Formulation

SOLO的框架如图2所示,核心思想是将实例分割问题转化为两个同时存在的类别感知预测问题。具体来说,SOLO将输入图像均匀地划分为 \(S\times S\) 个网格,如果一个对象的中心落入某个网格中,那么该网格负责 1)预测语义类别 2)分割该实例对象

Semantic Category 对每个网格,SOLO预测一个 \(C\) 维的输出表示语义类别的概率,其中 \(C\) 是总类别数。如果我们将输入图片划分成 \(S\times S\) 个网格,则输出为 \(S\times S\times C\),如图2上所示。这种设计是基于假设每个网格只属于单个实例。

Instance Mask 在预测语义类别的同时,每个positive网格还要生成对应的实例mask。对于划分为 \(S\times S\) 个网格的输入图片,最多有 \(S^2\) 个mask。我们将显式地在第三个维度(通道)编码这些mask。具体来说,instance mask分支输出维度为 \(H\times W\times S^2\),如图2下所示,其中第 \(k\) 个通道预测网格 \((i,j)\) 的mask,其中 \(k=i\cdot S+j\)。这样语义类别和类别不可知的mask之间就建立了一对一的对应关系。

预测实例mask的一种直接的方法是利用全卷积网络FCN,但传统的卷积在一定程度上是空间不变的,这种特性对于分类任务来说是有益的,因为它引入了鲁棒性。但这里我们希望模型是空间可变的,或者更精确的说是对位置敏感的,因此这里分割mask是基于网格的,并且必须被不同的特征通道分开。

因此作者使用了CoordConv(具体介绍见CoordConv(NeurIPS 2018)-CSDN博客),即在网络的开始,直接将归一化的像素坐标送入网络。具体来说就是创建一个和输入空间大小相等的张量,通道数为2,即包含每个像素点归一化的坐标,然后与输入拼接起来送入后续的层。这样假设原始特征tensor的维度为 \(H\times W\times D\),则新的张量维度为 \(H\times W\times (D+2)\),最后两个维度是像素的x-y坐标。

Network Architecture

SOLO前面是一个骨干网络和一个FPN的neck,FPN生成的不同大小的特征作为每个预测head的输入,每个head包括semantic category和instance mask。不同level的head的权重是共享的,网格数量不同,因此只有最后一个卷积层不共享。head的结构如图3所示。

SOLO Learning

Label Assignment 对于类别预测分支,网络需要输出每个网格的目标类别概率。具体来说,网格 \((i,j)\) 如果落入任意一个ground truth mask的中心区域,则这个网格视为一个正样本,否则视为负样本。Center sampling在最近的目标检测模型中常用,这里作者也使用了类似的方法。给定gt mask的重心 \((c_x,c_y)\) 和宽高 \(h,w\),中心区域center region通过一个常量缩放系数来控制 \(\epsilon:(c_x,c_y,\epsilon w,\epsilon h)\),本文设置 \(\epsilon=2\),每个gt mask平均有3个正样本。

除了实例类别的标签外,每个正样本还有一个binary分割mask标签。由于一共有 \(S^2\) 个网格,所以每张图片也有 \(S^2\) 个输出mask。正样本对应的分割mask标签中有前景,负样本没有。

Loss Function 训练损失定义如下

其中 \(L_{cate}\) 是Focal Loss,\(L_{mask}\) 是mask预测的损失

其中 \(i=\left \lfloor k/S \right \rfloor ,j=k\ mod\ S\),对网格按从左到右从上到下的顺序索引。\(N_{pos}\) 表示正样本的数量,\(\mathbf{p}^*,\mathbf{m}*\) 分别表示category和mask目标。\(\mathbb{1}\) 是indicator function,当 \(\mathbf{p}^*_{i,j}>0\) 时值为1否则为0。\(d_{mask}\) 采用Dice Loss,\(\lambda\) 设置为3。

代码解析

这里以mmdetection中的实现为例,输入shape为(2 ,3 736, 1344),2是batch size。Backbone为ResNet-50,经过backbone的输出为[(2, 256, 184, 336), (2, 512, 92, 168), (2, 1024, 46, 84), (2, 2048, 23, 42)],经过FPN neck处理后的输出为[(2, 256, 184, 336), (2, 256, 92, 168), (2, 256, 46, 84), (2, 256, 23, 42), (2, 256, 12, 21)]。FPN五个level的网格个数设置为num_grids=[40, 36, 24, 16, 12]

然后进入到solo_head.py中的forward函数,代码如下,其中加了一些注释

def forward(self, x: Tuple[Tensor]) -> tuple:
    """Forward features from the upstream network.

    Args:
        x (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.

    Returns:
        tuple: A tuple of classification scores and mask prediction.

            - mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
              Each element in the list has shape
              (batch_size, num_grids**2 ,h ,w).
            - mlvl_cls_preds (list[Tensor]): Multi-level scores.
              Each element in the list has shape
              (batch_size, num_classes, num_grids ,num_grids).
    """
    assert len(x) == self.num_levels
    # [(2,256,184,336),
    #  (2,256,92,168),
    #  (2,256,46,84),
    #  (2,256,23,42),
    #  (2,256,12,21)]
    feats = self.resize_feats(x)  # 这里为什么要resize?
    # [(2,256,92,168),
    #  (2,256,92,168),
    #  (2,256,46,84),
    #  (2,256,23,42),
    #  (2,256,23,42)]
    mlvl_mask_preds = []
    mlvl_cls_preds = []
    for i in range(self.num_levels):  # 5
        x = feats[i]
        mask_feat = x
        cls_feat = x
        # generate and concat the coordinate
        coord_feat = generate_coordinate(mask_feat.size(),
                                         mask_feat.device)  # (2,2,92,168)
        mask_feat = torch.cat([mask_feat, coord_feat], 1)  # (2,258,92,168)

        for mask_layer in self.mask_convs:
            mask_feat = mask_layer(mask_feat)  # (2,256,92,168)

        mask_feat = F.interpolate(
            mask_feat, scale_factor=2, mode='bilinear')  # (2,256,184,336), 这里为什么要上采样?
        mask_preds = self.conv_mask_list[i](mask_feat)  # (2,1600,184,336), 这里num_grids的设置依据是什么?

        # cls branch
        for j, cls_layer in enumerate(self.cls_convs):
            if j == self.cls_down_index:  # 0
                num_grid = self.num_grids[i]
                # (2,256,92,168)
                cls_feat = F.interpolate(
                    cls_feat, size=num_grid, mode='bilinear')  # (2,256,40,40)
            cls_feat = cls_layer(cls_feat)  # (2,256,40,40)

        cls_pred = self.conv_cls(cls_feat)  # (2,1,40,40)

        if not self.training:
            feat_wh = feats[0].size()[-2:]
            upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
            mask_preds = F.interpolate(
                mask_preds.sigmoid(), size=upsampled_size, mode='bilinear')
            cls_pred = cls_pred.sigmoid()
            # get local maximum
            local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
            keep_mask = local_max[:, :, :-1, :-1] == cls_pred
            cls_pred = cls_pred * keep_mask

        mlvl_mask_preds.append(mask_preds)
        mlvl_cls_preds.append(cls_pred)
    # [(2,1600,184,336),
    #  (2,1296,184,336),
    #  (2,576,92,168),
    #  (2,256,46,84),
    #  (2,144,46,84)]
    ##########
    # [(2,1,40,40),
    #  (2,1,36,36),
    #  (2,1,24,24),
    #  (2,1,16,16),
    #  (2,1,12,12)]
    return mlvl_mask_preds, mlvl_cls_preds

这里第一步的self.resize_feats不知道是为什么,把最大的特征图下采样2x,把最小的特征上采样2x。接下来遍历每个level的特征图,分别送入head,对各个level的特征图head是共享的。head分为mask分支和cls分支,如图3所示,两个分支前面都包含7个卷积,其中mask分支在输入特征后面concat了坐标map。第44行将mask分支的特征图又上采样了2x也不知道是为什么,最后通过self.conv_mask_list[i]将通道数映射为 \(S^2\)。cls分支是通过插值将spatial维度映射为 \((S,S)\),通过self.conv_cls将通道数映射为类别数。这里得到了模型的最终预测输出,下面就是计算target以及和target之间的loss了。

接下来进入到loss_by_feat函数,其中self._get_targets_single根据gt标签计算损失的target。代码如下

def _get_targets_single(self,
                        gt_instances: InstanceData,
                        featmap_sizes: Optional[list] = None) -> tuple:
    """Compute targets for predictions of single image.

    Args:
        gt_instances (:obj:`InstanceData`): Ground truth of instance
            annotations. It should includes ``bboxes``, ``labels``,
            and ``masks`` attributes.
        featmap_sizes (list[:obj:`torch.size`]): Size of each
            feature map from feature pyramid, each element
            means (feat_h, feat_w). Defaults to None.

    Returns:
        Tuple: Usually returns a tuple containing targets for predictions.

            - mlvl_pos_mask_targets (list[Tensor]): Each element represent
              the binary mask targets for positive points in this
              level, has shape (num_pos, out_h, out_w).
            - mlvl_labels (list[Tensor]): Each element is
              classification labels for all
              points in this level, has shape
              (num_grid, num_grid).
            - mlvl_pos_masks (list[Tensor]): Each element is
              a `BoolTensor` to represent whether the
              corresponding point in single level
              is positive, has shape (num_grid **2).
    """
    gt_labels = gt_instances.labels  # tensor([0, 0], device='cuda:0')
    device = gt_labels.device

    gt_bboxes = gt_instances.bboxes  # tensor([[375.1396, 306.0325, 551.2051, 437.0340], [791.7946, 298.9660, 968.8605, 434.3161]], device='cuda:0')
    gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                          (gt_bboxes[:, 3] - gt_bboxes[:, 1]))  # tensor([151.8712, 154.8092], device='cuda:0')

    gt_masks = gt_instances.masks.to_tensor(
        dtype=torch.bool, device=device)  # torch.Size([2, 736, 1344])

    mlvl_pos_mask_targets = []
    mlvl_labels = []
    mlvl_pos_masks = []
    for (lower_bound, upper_bound), stride, featmap_size, num_grid \
            in zip(self.scale_ranges, self.strides,
                   featmap_sizes, self.num_grids):

        mask_target = torch.zeros(
            [num_grid**2, featmap_size[0], featmap_size[1]],
            dtype=torch.uint8,
            device=device)  # (40^2, 184, 336)
        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        labels = torch.zeros([num_grid, num_grid],
                             dtype=torch.int64,
                             device=device) + self.num_classes  # (40,40), 注意这里加了num_classes
        pos_mask = torch.zeros([num_grid**2],
                               dtype=torch.bool,
                               device=device)  # (1600)

        gt_inds = ((gt_areas >= lower_bound) &
                   (gt_areas <= upper_bound)).nonzero().flatten()  # tensor([0, 1], device='cuda:0'),  torch.Size([2])
        if len(gt_inds) == 0:
            mlvl_pos_mask_targets.append(
                mask_target.new_zeros(0, featmap_size[0], featmap_size[1]))
            mlvl_labels.append(labels)
            mlvl_pos_masks.append(pos_mask)
            continue

        hit_gt_bboxes = gt_bboxes[gt_inds]
        hit_gt_labels = gt_labels[gt_inds]
        hit_gt_masks = gt_masks[gt_inds, ...]  # torch.Size([2, 736, 1344])

        pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
                              hit_gt_bboxes[:, 0]) * self.pos_scale  # tensor([17.6066, 17.7066], device='cuda:0')
        pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
                              hit_gt_bboxes[:, 1]) * self.pos_scale  # tensor([13.1002, 13.5350], device='cuda:0')

        # Make sure hit_gt_masks has a value
        valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0  # tensor([True, True], device='cuda:0')
        output_stride = stride / 2  # 4.0,  为什么要除以2

        # 先遍历level后遍历实例
        for gt_mask, gt_label, pos_h_range, pos_w_range, \
            valid_mask_flag in \
                zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
                    pos_w_ranges, valid_mask_flags):
            if not valid_mask_flag:
                continue
            upsampled_size = (featmap_sizes[0][0] * 4,
                              featmap_sizes[0][1] * 4)  # 原始输入大小, (736, 1344)
            center_h, center_w = center_of_mass(gt_mask)  # tensor(371.9828, device='cuda:0') tensor(464.5376, device='cuda:0')

            coord_w = int(
                floordiv((center_w / upsampled_size[1]), (1. / num_grid),
                         rounding_mode='trunc'))  # 12
            coord_h = int(
                floordiv((center_h / upsampled_size[0]), (1. / num_grid),
                         rounding_mode='trunc'))  # 18

            # left, top, right, down
            top_box = max(
                0,
                int(
                    floordiv(
                        (center_h - pos_h_range) / upsampled_size[0],
                        (1. / num_grid),
                        rounding_mode='trunc')))
            down_box = min(
                num_grid - 1,
                int(
                    floordiv(
                        (center_h + pos_h_range) / upsampled_size[0],
                        (1. / num_grid),
                        rounding_mode='trunc')))
            left_box = max(
                0,
                int(
                    floordiv(
                        (center_w - pos_w_range) / upsampled_size[1],
                        (1. / num_grid),
                        rounding_mode='trunc')))
            right_box = min(
                num_grid - 1,
                int(
                    floordiv(
                        (center_w + pos_w_range) / upsampled_size[1],
                        (1. / num_grid),
                        rounding_mode='trunc')))
            # 17,18,11,12

            top = max(top_box, coord_h - 1)
            down = min(down_box, coord_h + 1)
            left = max(coord_w - 1, left_box)
            right = min(right_box, coord_w + 1)
            # 17,18,11,12

            labels[top:(down + 1), left:(right + 1)] = gt_label  # center region内的grid都作为正样本
            # 如果两个实例的中心区域有重叠,那这里后一个实例的赋值岂不是会覆盖前一个的结果
            # 只有分类有center sampling,分割没有
            # ins
            gt_mask = np.uint8(gt_mask.cpu().numpy())
            # Follow the original implementation, F.interpolate is
            # different from cv2 and opencv
            gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)  # 1/4
            gt_mask = torch.from_numpy(gt_mask).to(device=device)  # (184, 336)

            for i in range(top, down + 1):
                for j in range(left, right + 1):
                    index = int(i * num_grid + j)
                    mask_target[index, :gt_mask.shape[0], :gt_mask.
                                shape[1]] = gt_mask
                    pos_mask[index] = True
        mlvl_pos_mask_targets.append(mask_target[pos_mask])
        mlvl_labels.append(labels)
        mlvl_pos_masks.append(pos_mask)
    return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks

这个函数只处理单张图片,其中外层循环遍历level,内层遍历所有实例。首先定义空的target,包括mask_target,shape为 \((S^2,H,W)\);分类的target即labels,shape为 \((S,S)\),这里实际还应该有第三个维度 \(C\) 为类别数,这里没有用one-hot形式,直接放入label的原始索引;此外还定义了一个pos_mask,shape为 \((S^2,)\),表示每个网格是否为正样本,正样本为True,负样本为False。

首先外层遍历每个level,每个level提前设定了负责检测的实例的大小范围,如下 

scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),

只有某个实例的大小在当前level的范围内,这个level才负责检测这个物体,即下面代码

gt_inds = ((gt_areas >= lower_bound) &
                   (gt_areas <= upper_bound)).nonzero().flatten()  # tensor([0, 1], device='cuda:0'),  torch.Size([2])

 接下来pos_w_rangespos_h_ranges表示物体的center region,其中self.pos_scale=0.2就是论文中的 \(\epsilon\)。然后遍历每个实例,通过函数center_of_mass计算物体的重心,coord_wcoord_h表示重心落入的格子分别是宽高方向的第几个,然后top_box, down_box, left_box, right_box表示根据重心落入的网格和目标center region的大小计算正样本网络的索引,下面的代码是根据正样本网格的索引赋值类别target和mask_target以及pos_mask。

for i in range(top, down + 1):
    for j in range(left, right + 1):
        index = int(i * num_grid + j)
        mask_target[index, :gt_mask.shape[0], :gt_mask.
                    shape[1]] = gt_mask
        pos_mask[index] = True

在得到实际学习的target后,就是与前面forward函数的输出即head的输出计算损失,这里不再赘述。

实验结果

表1是在COCO测试集上SOLO和其它实例分割模型的性能对比,可以看到SOLO取得了最优的结果。

  • 35
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值