CenterNet Explained (Objects as Points)

arXiv: https://arxiv.org/abs/1904.07850
GitHub: https://github.com/xingyizhou/CenterNet

1. Overall network structure:

[Figure: overall CenterNet architecture]

2. Feature extraction: encoder-decoder fully convolutional networks from semantic segmentation (DLA and Hourglass)

2.1 Hourglass
The Hourglass architecture was first proposed for human pose estimation in the ECCV 2016 paper Stacked Hourglass Networks for Human Pose Estimation. A stacked hourglass cascades several hourglass-shaped subnetworks, which lets the model capture multi-scale information.
The design is quite modular: the complete network is built from a regular composition of a few building blocks.
2.1.1 The bottom layer of the Hourglass network: the residual module

import torch
import torch.nn as nn


class residual(nn.Module):
    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(residual, self).__init__()

        self.conv1 = nn.Conv2d(inp_dim,
                               out_dim, (3, 3),
                               padding=(1, 1),
                               stride=(stride, stride),
                               bias=False)
        self.bn1 = nn.BatchNorm2d(out_dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_dim,
                               out_dim, (3, 3),
                               padding=(1, 1),
                               bias=False)
        self.bn2 = nn.BatchNorm2d(out_dim)

        # project the shortcut with a 1x1 conv when the shape changes, identity otherwise
        self.skip = nn.Sequential(nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False),
                                  nn.BatchNorm2d(out_dim)) \
            if stride != 1 or inp_dim != out_dim else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        bn1 = self.bn1(conv1)
        relu1 = self.relu1(bn1)

        conv2 = self.conv2(relu1)
        bn2 = self.bn2(conv2)

        skip = self.skip(x)
        return self.relu(bn2 + skip)
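
A quick shape check of the block (a minimal sketch; it only assumes the residual class above and a working PyTorch install):

x = torch.randn(1, 128, 64, 64)
blk = residual(3, 128, 256, stride=2)  # channel change + stride 2 -> projected skip
y = blk(x)
print(y.shape)  # torch.Size([1, 256, 32, 32])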

2.1.2 The hourglass submodule

class kp_module(nn.Module):
    '''
    kp_module is the basic hourglass building block.
    '''
    def __init__(self, n, dims, modules):
        super(kp_module, self).__init__()

        self.n = n

        curr_modules = modules[0]
        next_modules = modules[1]

        curr_dim = dims[0]
        next_dim = dims[1]

        # curr_modules x residual: curr_dim -> curr_dim -> ... -> curr_dim
        self.top = make_layer(3,  # spatial resolution unchanged
                              curr_dim,
                              curr_dim,
                              curr_modules,
                              layer=residual)
        self.down = nn.Sequential()  # placeholder, currently unused
        # curr_modules x residual: curr_dim -> next_dim -> ... -> next_dim
        self.low1 = make_layer(3,
                               curr_dim,
                               next_dim,
                               curr_modules,
                               layer=residual,
                               stride=2)  # downsample: halve the spatial resolution
        # next_modules x residual: next_dim -> next_dim -> ... -> next_dim
        if self.n > 1:
            # build the next level down via recursion
            self.low2 = kp_module(n - 1, dims[1:], modules[1:])
        else:
            # base case of the recursion
            self.low2 = make_layer(3,
                                   next_dim,
                                   next_dim,
                                   next_modules,
                                   layer=residual)
        # curr_modules x residual: next_dim -> next_dim -> ... -> next_dim -> curr_dim
        self.low3 = make_layer_revr(3,  # map the channels back to curr_dim
                                    next_dim,
                                    curr_dim,
                                    curr_modules,
                                    layer=residual)
        self.up = nn.Upsample(scale_factor=2)  # upsample to restore the spatial resolution

    def forward(self, x):
        up1 = self.top(x)
        down = self.down(x)
        low1 = self.low1(down)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up(low3)
        return up1 + up2

Two helper functions do most of the work here: make_layer (with stride=2) moves down the hourglass, halving the spatial resolution and widening the channels, while make_layer_revr maps the channels back on the way up, after which nn.Upsample restores the resolution. This down-then-up shape is why the structure is called an hourglass; a sketch of both helpers follows below.
The core construction is recursive, with the recursion depth controlled by n; the result is called an n-order hourglass module.
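
The two helpers are not shown in this post; below is a minimal sketch of how the reference reimplementation defines them (signatures assumed from the call sites above):

def make_layer(kernel_size, inp_dim, out_dim, modules, layer, stride=1):
    # first layer changes the dims (and possibly the stride), the rest keep them
    layers = [layer(kernel_size, inp_dim, out_dim, stride=stride)]
    layers += [layer(kernel_size, out_dim, out_dim) for _ in range(modules - 1)]
    return nn.Sequential(*layers)


def make_layer_revr(kernel_size, inp_dim, out_dim, modules, layer):
    # mirror image of make_layer: keep the dims first, change them in the last layer
    layers = [layer(kernel_size, inp_dim, inp_dim) for _ in range(modules - 1)]
    layers.append(layer(kernel_size, inp_dim, out_dim))
    return nn.Sequential(*layers)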
[Figure: recursive structure of the n-order hourglass module]
2.1.3 The full Hourglass model

class exkp(nn.Module):
    '''
    Top-level model wrapper.
    large hourglass: nstack = 2
    small hourglass: nstack = 1
    n controls the order of each hourglass; both variants use 5-order hourglasses:
    exkp(n=5, nstack=2, dims=[256, 256, 384, 384, 384, 512], modules=[2, 2, 2, 2, 2, 4])
    '''
    def __init__(self, n, nstack, dims, modules, cnv_dim=256, num_classes=80):
        super(exkp, self).__init__()

        self.nstack = nstack  # number of stacked hourglasses
        self.num_classes = num_classes

        curr_dim = dims[0]

        # stem: quickly downsample the input to 1/4 resolution
        self.pre = nn.Sequential(convolution(7, 3, 128, stride=2),
                                 residual(3, 128, curr_dim, stride=2))

        # stack nstack hourglass modules
        self.kps = nn.ModuleList(
            [kp_module(n, dims, modules) for _ in range(nstack)])

        self.cnvs = nn.ModuleList(
            [convolution(3, curr_dim, cnv_dim) for _ in range(nstack)])

        self.inters = nn.ModuleList(
            [residual(3, curr_dim, curr_dim) for _ in range(nstack - 1)])

        self.inters_ = nn.ModuleList([
            nn.Sequential(nn.Conv2d(curr_dim, curr_dim, (1, 1), bias=False),
                          nn.BatchNorm2d(curr_dim)) for _ in range(nstack - 1)
        ])
        self.cnvs_ = nn.ModuleList([
            nn.Sequential(nn.Conv2d(cnv_dim, curr_dim, (1, 1), bias=False),
                          nn.BatchNorm2d(curr_dim)) for _ in range(nstack - 1)
        ])
        # heatmap layers
        self.hmap = nn.ModuleList([
            make_kp_layer(cnv_dim, curr_dim, num_classes)  # heatmap head: num_classes output channels
            for _ in range(nstack)
        ])
        for hmap in self.hmap:
            # -2.19 initializes the heatmap bias as in the focal-loss setup (see Sec. 4.1
            # of the paper): -ln((1 - pi) / pi) with pi = 0.1, i.e. -ln(9) ≈ -2.197
            hmap[-1].bias.data.fill_(-2.19)

        # regression layers
        self.regs = nn.ModuleList(
            [make_kp_layer(cnv_dim, curr_dim, 2) for _ in range(nstack)])  # offset head: 2 output channels
        self.w_h_ = nn.ModuleList(
            [make_kp_layer(cnv_dim, curr_dim, 2) for _ in range(nstack)])  # width/height head: 2 output channels

        self.relu = nn.ReLU(inplace=True)

    def forward(self, image):
        inter = self.pre(image)

        outs = []
        for ind in range(self.nstack):  # run the stacked hourglasses in sequence
            kp = self.kps[ind](inter)
            cnv = self.cnvs[ind](kp)

            if self.training or ind == self.nstack - 1:
                outs.append([
                    self.hmap[ind](cnv), self.regs[ind](cnv),
                    self.w_h_[ind](cnv)
                ])

            if ind < self.nstack - 1:
                inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
                inter = self.relu(inter)
                inter = self.inters[ind](inter)
        return outs
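
A minimal forward-pass sketch, assuming the remaining helpers from the reference repo (convolution and make_kp_layer) are defined alongside the code above:

net = exkp(n=5, nstack=2,
           dims=[256, 256, 384, 384, 384, 512],
           modules=[2, 2, 2, 2, 2, 4])
img = torch.randn(1, 3, 512, 512)
outs = net(img)               # in training mode: one [hmap, regs, w_h_] triple per stack
hmap, regs, w_h = outs[-1]
print(hmap.shape, regs.shape, w_h.shape)
# expected: [1, 80, 128, 128], [1, 2, 128, 128], [1, 2, 128, 128] (512 / 4 = 128)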

Note the inters variables: they implement the intermediate-supervision path. A loss can be attached at this point (the blue part in the figure below); 1x1 convolutions then remap the features back to the matching channel count before the two branches are summed.

[Figure: stacked hourglass with the intermediate-supervision branch highlighted in blue]
Reference: https://zhuanlan.zhihu.com/p/165563550

3. The loss function consists of three parts:

The total objective is

$$L_{det} = L_k + \lambda_{size} L_{size} + \lambda_{off} L_{off}$$

with $\lambda_{size} = 0.1$ and $\lambda_{off} = 1$, where:

the penalty-reduced focal loss on the keypoint heatmap ($\alpha = 2$, $\beta = 4$ in the paper)

$$L_k = \frac{-1}{N} \sum_{xyc} \begin{cases} (1 - \hat{Y}_{xyc})^{\alpha} \log(\hat{Y}_{xyc}) & \text{if } Y_{xyc} = 1 \\ (1 - Y_{xyc})^{\beta} (\hat{Y}_{xyc})^{\alpha} \log(1 - \hat{Y}_{xyc}) & \text{otherwise} \end{cases}$$

the local offset loss that recovers the discretization error introduced by the output stride $R$

$$L_{off} = \frac{1}{N} \sum_{p} \left| \hat{O}_{\tilde{p}} - \left( \frac{p}{R} - \tilde{p} \right) \right|$$

and the object size regression loss

$$L_{size} = \frac{1}{N} \sum_{k=1}^{N} \left| \hat{S}_{p_k} - s_k \right|$$
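
A quick worked example of the offset term (numbers chosen for illustration): with output stride R = 4, a keypoint at p = (146, 57) lands at p/R = (36.5, 14.25) on the output map, which is discretized to p̃ = (36, 14); the offset head is therefore trained to predict (0.5, 0.25) at that location.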
The corresponding code:

def _neg_loss(preds, gt):
    # penalty-reduced pixel-wise focal loss on the heatmap (alpha=2, beta=4)
    pos_inds = gt.eq(1)  # ground-truth peaks
    neg_inds = gt.lt(1)  # every other location

    # (1 - Y)^beta down-weights negatives that lie close to a ground-truth peak
    neg_weights = torch.pow(1 - gt[neg_inds], 4)

    loss = 0
    for pred in preds:
        pos_pred = pred[pos_inds]
        neg_pred = pred[neg_inds]

        pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)
        neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights

        num_pos  = pos_inds.float().sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if pos_pred.nelement() == 0:
            loss = loss - neg_loss
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
    return loss

def _sigmoid(x):
    # clamp away from 0 and 1 so the log terms in the focal loss stay finite
    x = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
    return x
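
A toy check of _neg_loss (illustrative values only):

gt = torch.zeros(1, 1, 4, 4)
gt[0, 0, 1, 1] = 1.0                  # one ground-truth peak
pred = torch.full((1, 1, 4, 4), 0.1)  # already-sigmoided predictions
pred[0, 0, 1, 1] = 0.9
print(_neg_loss([pred], gt))          # small loss: confident at the peak, low elsewhere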

def _ae_loss(tag0, tag1, mask):
    # associative-embedding loss: "pull" the two corner embeddings of the same
    # object toward their mean, "push" the means of different objects apart
    num  = mask.sum(dim=1, keepdim=True).float()
    tag0 = tag0.squeeze()
    tag1 = tag1.squeeze()

    tag_mean = (tag0 + tag1) / 2

    tag0 = torch.pow(tag0 - tag_mean, 2) / (num + 1e-4)
    tag0 = tag0[mask].sum()
    tag1 = torch.pow(tag1 - tag_mean, 2) / (num + 1e-4)
    tag1 = tag1[mask].sum()
    pull = tag0 + tag1

    mask = mask.unsqueeze(1) + mask.unsqueeze(2)
    mask = mask.eq(2)
    num  = num.unsqueeze(2)
    num2 = (num - 1) * num
    dist = tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2)
    dist = 1 - torch.abs(dist)
    dist = nn.functional.relu(dist, inplace=True)
    dist = dist - 1 / (num + 1e-4)
    dist = dist / (num2 + 1e-4)
    dist = dist[mask]
    push = dist.sum()
    return pull, push
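
The shapes are easiest to see with a toy call (shapes assumed from how the function indexes its arguments; a byte mask is used here, as in the PyTorch-0.4-era original):

tl_tag = torch.randn(2, 5, 1)  # top-left corner embeddings, K = 5 object slots
br_tag = torch.randn(2, 5, 1)  # bottom-right corner embeddings
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 0, 0, 0, 0]], dtype=torch.uint8)  # which slots hold real objects
pull, push = _ae_loss(tl_tag, br_tag, mask)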

def _regr_loss(regr, gt_regr, mask):
    # masked smooth-L1 loss, normalized by the number of valid objects
    num  = mask.float().sum()
    mask = mask.unsqueeze(2).expand_as(gt_regr)

    regr    = regr[mask]
    gt_regr = gt_regr[mask]

    # size_average=False is deprecated; reduction='sum' is the modern equivalent
    regr_loss = nn.functional.smooth_l1_loss(regr, gt_regr, reduction='sum')
    regr_loss = regr_loss / (num + 1e-4)
    return regr_loss


class AELoss(nn.Module):
    def __init__(self, pull_weight=1, push_weight=1, regr_weight=1, focal_loss=_neg_loss):
        super(AELoss, self).__init__()

        self.pull_weight = pull_weight
        self.push_weight = push_weight
        self.regr_weight = regr_weight
        self.focal_loss  = focal_loss
        self.ae_loss     = _ae_loss
        self.regr_loss   = _regr_loss

    def forward(self, outs, targets):
        stride = 8

        tl_heats = outs[0::stride]
        br_heats = outs[1::stride]
        ct_heats = outs[2::stride]
        tl_tags  = outs[3::stride]
        br_tags  = outs[4::stride]
        tl_regrs = outs[5::stride]
        br_regrs = outs[6::stride]
        ct_regrs = outs[7::stride]

        gt_tl_heat = targets[0]
        gt_br_heat = targets[1]
        gt_ct_heat = targets[2]
        gt_mask    = targets[3]
        gt_tl_regr = targets[4]
        gt_br_regr = targets[5]
        gt_ct_regr = targets[6]
        
        # focal loss
        focal_loss = 0

        tl_heats = [_sigmoid(t) for t in tl_heats]
        br_heats = [_sigmoid(b) for b in br_heats]
        ct_heats = [_sigmoid(c) for c in ct_heats]

        focal_loss += self.focal_loss(tl_heats, gt_tl_heat)
        focal_loss += self.focal_loss(br_heats, gt_br_heat)
        focal_loss += self.focal_loss(ct_heats, gt_ct_heat)

        # tag loss
        pull_loss = 0
        push_loss = 0

        for tl_tag, br_tag in zip(tl_tags, br_tags):
            pull, push = self.ae_loss(tl_tag, br_tag, gt_mask)
            pull_loss += pull
            push_loss += push
        pull_loss = self.pull_weight * pull_loss
        push_loss = self.push_weight * push_loss

        regr_loss = 0
        for tl_regr, br_regr, ct_regr in zip(tl_regrs, br_regrs, ct_regrs):
            regr_loss += self.regr_loss(tl_regr, gt_tl_regr, gt_mask)
            regr_loss += self.regr_loss(br_regr, gt_br_regr, gt_mask)
            regr_loss += self.regr_loss(ct_regr, gt_ct_regr, gt_mask)
        regr_loss = self.regr_weight * regr_loss

        loss = (focal_loss + pull_loss + push_loss + regr_loss) / len(tl_heats)
        return loss
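
One thing worth spelling out: with stride = 8, forward assumes outs is a flat list in which every prediction stage contributes eight tensors in the fixed order [tl_heat, br_heat, ct_heat, tl_tag, br_tag, tl_regr, br_regr, ct_regr], with targets supplying the matching ground truth. This variant of the loss, built around top-left/bottom-right corner heatmaps and associative-embedding tags, follows the CornerNet-style corner formulation rather than the single center-point heads of the exkp model shown earlier.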