Hourglass网络的理解和代码分析

1. 前言

目标检测最新的一些方法中都采用了人脸关键点检测hourglass网络作为检测的主干网络,如CornerNet-Lite系列。目标检测将anchor生成box的方式替换关键点预测,采用hourglass网络的优点在于物体特征点可能出现在网络的不同层,如果采用VGG,或者ResNet网络,最后的特征图很难包含检测物体的所有关键点。
在这里插入图片描述

2.一个简单的hourglass网络结构

堆叠hourglass网络的结构如下图,每一个白色的box表示一个residual模块。

在这里插入图片描述堆叠hourglass 网络是个递归的结构,输入从左到中间,维度增加,特征map的大小变小,从中间到右,维度减少,特征map变大,即C1嵌套C2,C2嵌套C3,依次类推,,,C5,C6, C7是residual模块串联,总网络是4层的嵌套。

直观一点,沙漏网络:哈哈哈!!!
在这里插入图片描述

residual模块参照下图,这个模块的特性可以对特征图升维和降维,并且不改变特征图的size(W, H)。
每一个白色的box大致可以理解为下面的模块。
在这里插入图片描述
如果上图还是看起来复杂,看下图:(就是个利用1*1卷积对图像升维或者降维)
参考:resnet网络
在这里插入图片描述

3. Hourglass 网络的定义:

3.1 方法一:

1.先定义residual模块:
residual传入参数(输入特征图数,输出特征图数)

class residual(nn.Module):
    def __init__(self, inp_dim, out_dim, k=3, stride=1):
        super(residual, self).__init__()
        p = (k - 1) // 2

        self.conv1 = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(p, p), stride=(stride, stride), bias=False)
        self.bn1   = nn.BatchNorm2d(out_dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_dim, out_dim, (k, k), padding=(p, p), bias=False)
        self.bn2   = nn.BatchNorm2d(out_dim)
        
        self.skip  = nn.Sequential(
            nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False),
            nn.BatchNorm2d(out_dim)
        ) if stride != 1 or inp_dim != out_dim else nn.Sequential()
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        bn1   = self.bn1(conv1)
        relu1 = self.relu1(bn1)

        conv2 = self.conv2(relu1
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值