SOLO: Segmenting Objects by Locations (ECCV 2020) Code Reading Notes

This post dissects the SOLO instance segmentation algorithm, focusing on the SOLO head architecture, the loss function, and the positive-sample selection strategy. Reading the code shows how positive samples are taken around the instance's center of mass, and how this relates to the center-sampling schemes used in FCOS and PolarMask. It also walks through SOLO's implementation details, including the key operations in single_stage_ins.py and solo_head.py.
Segmenting Objects by Locations

If this post helps you, I'd appreciate a like~

SOLO head network architecture

[Figure: SOLO head network architecture]
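Briefly: for each FPN level with grid number S, the category branch resizes the feature map to S×S and applies the stacked convs to predict an S×S×C category score map, while the mask branch appends two normalized coordinate channels (CoordConv, hence the in_channels + 2 seen in the code below) and predicts S² mask channels, one channel per grid cell.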

Loss function

[Figure: loss function definition from the paper]
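For reference, the total loss in the paper is L = L_cate + λ · L_mask: L_cate is a Focal Loss on the S×S×C category grids, L_mask is the Dice loss averaged over the positive grid cells, and λ = 3 (this is the ins_loss_weight = 3 seen in the head code below).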

Positive sample selection

In the paper's own words:
[Figure: excerpt from the paper on positive sample assignment]
At first reading, this did not fully click for me. After going through the code, though, I gained a new understanding of positive sample selection and of how it is combined with a fully convolutional network; it is a nice case of theory meeting practice, where reading the code lets you reflect on and deepen your understanding of the paper's ideas.
FCOS and PolarMask also use a form of center sampling. Both papers note that a fully convolutional network could treat every point inside the gt box as a positive example, but that would be computationally expensive, and points near the edge of the bbox would regress poorly. Sampling positives around the center of mass (SOLO uses the mask centroid as the center) is therefore very reasonable.
The following figures are taken from an excellent blog post: blog link
As shown in the figure, in the original image the blue boxes are the grid cells that divide the image evenly, here set to 5×5. The green box is the object's gt box, the yellow box is the gt box shrunk to 0.2 of its size, and the red boxes are the grid cells responsible for predicting this instance.
The black-and-white images below visualize the targets of the mask branch; for display purposes the different channels are concatenated. In the first image on the left there is a single instance whose gt box, shrunk to 0.2×, covers two grid cells, so those two cells are responsible for predicting it.
In the mask branch below, only two FPN outputs match this instance, so the channels corresponding to the red cells predict its mask. The second image contains instances of various sizes; on the FPN mask-branch outputs you can see that the levels handle instances of different scales, from small to large.
[Figure: grid division, shrunk gt box, and mask-branch target visualization]

The figure below, from the original paper, also shows clearly how instances are placed onto the right channels for prediction according to their gt areas and the grid cells they fall in. First, each gt is assigned to an FPN level according to gt_areas. Then, within the same level, if there are several instances, each GT is matched to grid cells using the preset grid: the region around the GT's centroid, obtained by shrinking the gt box by sigma = 0.2 (with the box rescaled to that FPN level's output feature map size), decides which cells are positive. A small code sketch of this logic follows the figure.
[Figure: assigning instances to FPN levels by gt area and to grid cells by centroid]
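To make the assignment concrete, here is a minimal sketch of the logic (my own simplification, not code from the repo; the helper names are made up). It roughly follows solo_target_single in solo_head.py: pick the FPN level(s) whose scale range contains the square root of the gt area, then in that level's S×S grid mark as positive every cell covered by the center region, i.e. the gt box shrunk by sigma = 0.2 around its center (the real code uses the mask's center of mass rather than the box center).

import math

# Values mirror the config noted in the solo_head code below; helper names are mine.
SCALE_RANGES = ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
NUM_GRIDS = [40, 36, 24, 16, 12]

def pick_fpn_levels(gt_box):
    """FPN level indices whose scale range contains sqrt(gt box area)."""
    x1, y1, x2, y2 = gt_box
    scale = math.sqrt((x2 - x1) * (y2 - y1))
    return [i for i, (lo, hi) in enumerate(SCALE_RANGES) if lo <= scale <= hi]

def positive_cells(gt_box, img_h, img_w, num_grid, sigma=0.2):
    """Grid cells covered by the sigma-shrunk box around the instance center."""
    x1, y1, x2, y2 = gt_box
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    half_w, half_h = (x2 - x1) * sigma / 2, (y2 - y1) * sigma / 2

    coord_x = int(cx / img_w * num_grid)   # grid cell containing the center
    coord_y = int(cy / img_h * num_grid)
    # extent of the shrunk box in grid coords, limited to +-1 cell as in the repo
    left  = max(max(0, int((cx - half_w) / img_w * num_grid)), coord_x - 1)
    right = min(min(num_grid - 1, int((cx + half_w) / img_w * num_grid)), coord_x + 1)
    top   = max(max(0, int((cy - half_h) / img_h * num_grid)), coord_y - 1)
    down  = min(min(num_grid - 1, int((cy + half_h) / img_h * num_grid)), coord_y + 1)
    return [(i, j) for i in range(top, down + 1) for j in range(left, right + 1)]

# A 320x200 box in an 800x1216 image has sqrt(area) of roughly 253, so it matches the
# scale ranges of levels 2 and 3; in each of those grids a small block of cells around
# the instance center becomes positive.
box = (448, 300, 768, 500)
for lvl in pick_fpn_levels(box):
    print(lvl, positive_cells(box, 800, 1216, NUM_GRIDS[lvl]))

Each positive cell then receives the instance's class label in the S×S category target, and the instance's binary mask (resized to that level's feature-map size) in the corresponding channel of the S²-channel mask target.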

1. SOLO/mmdet/models/detectors/single_stage_ins.py

single_stage_ins.py wires together the backbone (ResNet), the neck (FPN), and the head (SOLO head), and implements their forward passes.
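Before reading the code, here is an abridged, illustrative config in mmdetection style showing how the three parts are specified. The field values follow the comments in the code below (num_classes 81, stacked_convs 7, strides [8, 8, 16, 32, 32], sigma 0.2, num_grids [40, 36, 24, 16, 12], FocalLoss for the category branch, Dice loss with weight 3 for the mask branch); it is a sketch, not a verbatim copy of the repo's config file.

model = dict(
    type='SOLO',
    pretrained='torchvision://resnet50',
    backbone=dict(type='ResNet', depth=50, num_stages=4,
                  out_indices=(0, 1, 2, 3), frozen_stages=1, style='pytorch'),
    neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048],
              out_channels=256, start_level=0, num_outs=5),
    bbox_head=dict(type='SOLOHead',
                   num_classes=81,
                   in_channels=256,
                   stacked_convs=7,
                   seg_feat_channels=256,
                   strides=[8, 8, 16, 32, 32],
                   scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
                   sigma=0.2,
                   num_grids=[40, 36, 24, 16, 12],
                   loss_ins=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
                   loss_cate=dict(type='FocalLoss', use_sigmoid=True,
                                  gamma=2.0, alpha=0.25, loss_weight=1.0)))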

import torch.nn as nn

from mmdet.core import bbox2result
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector
import pdb

@DETECTORS.register_module
class SingleStageInsDetector(BaseDetector):

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 mask_feat_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageInsDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone) # 1.build_backbone --> resnet
        if neck is not None:
            self.neck = builder.build_neck(neck) # 2.build_neck --> fpn
        if mask_feat_head is not None:
            self.mask_feat_head = builder.build_head(mask_feat_head)
        #pdb.set_trace()

        self.bbox_head = builder.build_head(bbox_head) # 3.build_head --> solo head

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained) # 'torchvision://resnet50'

    def init_weights(self, pretrained=None):
        super(SingleStageInsDetector, self).init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        if self.with_neck:
            if isinstance(self.neck, nn.Sequential):
                for m in self.neck:
                    m.init_weights()
            else:
                self.neck.init_weights()
        if self.with_mask_feat_head:
            if isinstance(self.mask_feat_head, nn.Sequential):
                for m in self.mask_feat_head:
                    m.init_weights()
            else:
                self.mask_feat_head.init_weights()
        #pdb.set_trace()
        self.bbox_head.init_weights()

    # forward: extract features with the backbone and the neck
    def extract_feat(self, img):
        x = self.backbone(img) # resnet forward        
        if self.with_neck:
            x = self.neck(x) # fpn forward
        return x
    '''
    after neck feature map:x
        (Pdb) x[0].shape
        torch.Size([2, 256, 200, 304])
        (Pdb) x[1].shape
        torch.Size([2, 256, 100, 152])
        (Pdb) x[2].shape
        torch.Size([2, 256, 50, 76])
        (Pdb) x[3].shape
        torch.Size([2, 256, 25, 38])
        (Pdb) x[4].shape
        torch.Size([2, 256, 13, 19])

    '''
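    # Sanity check: the five FPN levels downsample the 800x1216 padded input by
    # strides 4, 8, 16, 32, 64, which matches the shapes printed above:
    #   800/4 = 200, 1216/4 = 304; 800/8 = 100, 1216/8 = 152; ...;
    #   ceil(800/64) = 13, ceil(1216/64) = 19.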
    def forward_dummy(self, img):
        x = self.extract_feat(img)
        outs = self.bbox_head(x)
        return outs

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        # 1. img 
            # e.g. torch.Size([2, 3, 800, 1216]); h and w are padded to the largest size in the batch

        # 2. img_metas
            # eg.
            #[
            # {'filename': 'data/coco2017/train2017/000000559012.jpg', 
            #   'ori_shape': (508, 640, 3), 
            #   'img_shape': (800, 1008, 3), 
            #   'pad_shape': (800, 1216, 3), 
            #   'scale_factor': 1.8823529411764706, 
            #   'flip': False, 
            #   'img_norm_cfg': {'mean': array([123.675, 116.28 , 103.53 ], dtype=float32), 
            #   'std': array([58.395, 57.12 , 57.375], dtype=float32), 
            #   'to_rgb': True}}, 
            #
            # {'filename': 'data/coco2017/train2017/000000532426.jpg', 
            #   'ori_shape': (333, 640, 3), 'img_shape': (753, 1333, 3), 
            #   'pad_shape': (800, 1088, 3), 'scale_factor': 2.4024024024024024,
            #   'flip': False, 
            #   'img_norm_cfg': {'mean': array([123.675, 116.28 , 103.53 ], dtype=float32), 
            #   'std': array([58.395, 57.12 , 57.375], dtype=float32), 
            #   'to_rgb': True}}
            # ]
  
        # 3. gt_bboxes
            # eg.
            # gt_bboxes represents  'bbox' of coco datasets
            # type(gt_bboxes) --> list 
            # len(gt_bboxes) --> batch_size(ie. img per gpu) eg. 2
            # type(gt_bboxes[idx]) --> tensor
            # gt_bboxes[idx].size() --> [instances, 4]  '4' represents [x1, y1, x2, y2]
            # [6, 4] [9, 4]


        # 4. gt_labels
            # eg.
            # gt_labels represents 'category_id' of coco datasets
            # type(gt_labels) --> list 
            # len(gt_labels) --> batch_size(img per gpu) eg. 2
            # type(gt_labels[idx]) --> tensor
            # gt_labels[idx].size() --> number of instances, matching gt_bboxes, e.g. gt_bboxes[7 or 13, 4] --> gt_labels[7 or 13]
            # 6 , 9

        # 5. gt_masks
            # eg.
            # type(gt_masks) --> list
            # len(gt_masks) --> batch_size(img per gpu) eg. 2
            # type(gt_bboxes[idx]) --> list
            # (6, 800, 1216), (9, 800, 1088) represent (num instances, pad h, pad w)


        x = self.extract_feat(img) #    forward backbone and  fpn
        # solo_head forward
        outs = self.bbox_head(x) # forward solo_head
        # outs: e.g. five levels per branch
        # 1.ins_pred:
        # outs[0][0].size() --> torch.Size([2, 1600, 200, 336])
        # outs[0][1].size() --> torch.Size([2, 1296, 200, 336]) 
        # outs[0][2].size() --> torch.Size([2, 1024, 100, 168])
        # outs[0][3].size() --> torch.Size([2, 256, 50, 84])
        # outs[0][4].size() --> torch.Size([2, 144, 50, 84])
        # 

        # 2.cate_pred:
        # outs[1][0].size() --> torch.Size([2, 80, 40, 40])
        # outs[1][1].size() --> torch.Size([2, 80, 36, 36])
        # outs[1][2].size() --> torch.Size([2, 80, 24, 24])
        # outs[1][3].size() --> torch.Size([2, 80, 16, 16])
        # outs[1][4].size() --> torch.Size([2, 80, 12, 12])
        # 

        if self.with_mask_feat_head:
            mask_feat_pred = self.mask_feat_head(
                x[self.mask_feat_head.
                  start_level:self.mask_feat_head.end_level + 1])
            loss_inputs = outs + (mask_feat_pred, gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, gt_masks, img_metas, self.train_cfg) 
            # tuple len(outs) = 2  len(loss_inputs) = 7

        # compute SOLO loss
        losses = self.bbox_head.loss(
            *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def simple_test(self, img, img_meta, rescale=False):
        x = self.extract_feat(img)
        outs = self.bbox_head(x, eval=True) # when testing , eval = True rescale=True
        if self.with_mask_feat_head: # False
            mask_feat_pred = self.mask_feat_head(
                x[self.mask_feat_head.
                  start_level:self.mask_feat_head.end_level + 1])
            seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)
        else:
            seg_inputs = outs + (img_meta, self.test_cfg, rescale) # forward backbone fpn and solo_head 
        seg_result = self.bbox_head.get_seg(*seg_inputs) # get_seg()
        return seg_result  

    def aug_test(self, imgs, img_metas, rescale=False):
        raise NotImplementedError

2. SOLO/mmdet/models/anchor_heads/solo_head.py

Note: a printout of one batch of input data is appended at the very bottom.

import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv, roi_align
from mmdet.core import multi_apply, bbox2roi, matrix_nms
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
import pdb
import math
INF = 1e8

from scipy import ndimage

def points_nms(heat, kernel=2):
    # kernel must be 2
    hmax = nn.functional.max_pool2d(
        heat, (kernel, kernel), stride=1, padding=1)
    keep = (hmax[:, :, :-1, :-1] == heat).float() # (a == b) on tensors gives a bool matrix (True/False); .float() turns it into 1/0 so it can be multiplied; (hmax[:, :, :-1, :-1] == heat).bool() would convert it back.
    return heat * keep # a point survives only if it is the maximum of the 2x2 window formed with its top/left/top-left neighbours, which suppresses duplicate adjacent peaks on the category heatmap
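# Illustration (worked out by hand, not output from the repo):
#   heat = [[0.7, 0.9],          points_nms(heat) = [[0.7, 0.9],
#           [0.8, 0.2]]   -->                        [0.8, 0.0]]
#   only 0.2 is zeroed, because a larger value (0.9) lies in its top-left 2x2 window.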

def dice_loss(input, target):
    input = input.contiguous().view(input.size()[0], -1) # [instances , w * h]
    target = target.contiguous().view(target.size()[0], -1).float() # [instances , w * h]

    a = torch.sum(input * target, 1)
    b = torch.sum(input * input, 1) + 0.001
    c = torch.sum(target * target, 1) + 0.001
    e = (2 * a) / (b + c)
    print('dice_loss:', 1-e)
    #pdb.set_trace() # [24]
    return 1-e
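# Tiny usage sketch (made-up numbers, worked out by hand): each row is one
# instance's flattened mask prediction / target.
#   pred   = torch.tensor([[0.9, 0.8, 0.1, 0.0], [0.2, 0.1, 0.7, 0.9]])
#   target = torch.tensor([[1.,  1.,  0.,  0. ], [0.,  0.,  1.,  1. ]])
#   dice_loss(pred, target)  -->  tensor([0.0179, 0.0453]), one value per instance;
#   the closer the predicted mask is to its target, the closer the loss is to 0.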

@HEADS.register_module
class SOLOHead(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 with_deform=False,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(SOLOHead, self).__init__()
        self.num_classes = num_classes # 81
        self.seg_num_grids = num_grids # [40, 36, 24, 16, 12]
        self.cate_out_channels = self.num_classes - 1 # 80
        self.in_channels = in_channels #256
        self.seg_feat_channels = seg_feat_channels # 256
        self.stacked_convs = stacked_convs # 7
        self.strides = strides # [8, 8, 16, 32, 32]
        self.sigma = sigma # 0.2
        self.cate_down_pos = cate_down_pos # 0
        self.base_edge_list = base_edge_list # (16, 32, 64, 128, 256)
        self.scale_ranges = scale_ranges # ((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
        self.with_deform = with_deform #False
        #loss_cate: {'type': 'FocalLoss', 'use_sigmoid': True, 'gamma': 2.0, 'alpha': 0.25, 'loss_weight': 1.0}

        self.loss_cate = build_loss(loss_cate) # FocalLoss() <class 'mmdet.models.losses.focal_loss.FocalLoss'>
        self.ins_loss_weight = loss_ins['loss_weight'] # 3
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()
        #pdb.set_trace()

    # init  ins_convs, cate_convs, solo_ins_list, solo_cate
    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs = nn.ModuleList()
        self.cate_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # CoordConv: add 2 extra channels for the x and y coordinate maps
            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.ins_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

        self.solo_ins_list = nn.ModuleList()

        # my modification: [h, w, 256] --> [h, w, min(h/s, w/s)^2]
        self.solo_sa_module = nn.ModuleList()


        # [h, w , 256] ---> [h, w, s*s]

        # modification: the original 1x1 conv is commented out below and replaced
        '''
        for seg_num_grid in self.seg_num_grids:
            self.solo_ins_list.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid**2, 1))
        '''
        for seg_num_grid in self.seg_num_grids:
            self.solo_ins_list.append(
                nn.Conv2d(
                seg_num_grid**2, seg_num_grid**2, 1))
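        # Note (my reading of the modification above): the original SOLO head, commented
        # out a few lines up, maps the 256-channel instance feature to S*S channels with
        # a 1x1 conv, one output channel per grid cell; the modified conv instead expects
        # an S*S-channel input (presumably produced by the solo_sa_module declared above)
        # and keeps the channel count at S*S.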
      