Stacked Hourglass Network(ECCV'16) 论文笔记+pytorch代码阅读

pytorch 代码:https://github.com/princeton-vl/pytorch_stacked_hourglass

论文原理

Summary

论文设计了一个新的框架Stacked Hourglass Network,通过提取和融合多尺度特征,来更好的捕获人体关键点的各种空间关系(spatial relationship)。

Motivation

Hourglass模块设计的初衷就是为了捕捉每个尺度下的信息,因为捕捉像脸,手这些部分的时候需要局部的特征,而最后对人体姿态进行预测的时候又需要整体的信息。(感受野较大的feature map可以捕捉到更高阶的特征和全局上下文)

  • Hourglass Layer:The network captures and consolidates information across all scales of the image. 通过一步步的下采样获取不同尺度的feature map,再通过上采样和skip layers对不同尺度下的特征进行融合;与其他结构相比有更加对称的结构

  • multiple iterative stages:Hourglass层可以在局部和全局上下文信息中提取feature,并生成预测。之后用多个Hourglass迭代,可以对高阶特征进行多次处理,来进行进一步评估,并重新估计高阶的空间关系(spatial relationship)

  • Intermediate Supervision:在每个stage都计算loss, 避免深度多阶段网络中常见的梯度消失问题。

Model

本论文网络的设计深刻的体现出从 Block 到 Layer 最后形成网络的方法。

Resiual Model

在这里插入图片描述
论文使用Residual模块来提取特征,本文中残差模块不改变输入的高度和高度,仅改变Channel,而且保证输出Channel始终为256。

实际上残差模块是BottleNeck。作者进行一系列的尝试,从卷积核较大的标准卷积到一些新方法,像Residual,Inception模块。

Hourglass

在这里插入图片描述
Hourglass模块由上下两路构成,下路–获得较小尺度的特征
上路:利用Residual提取原尺度的特征,与下路相加后便得到融合的多尺度的特征
上采样:最近邻插值

下路:获得较小尺度的特征
下采样:max pooling (2*2)
第一个Residual:用来获取下采样之后的特征
第二个Residual:可以换成Hourglass层,得到高阶Hourglass
第三个Residual:我猜测是用来保证对称结构,hhhh

多阶Hourglass
多阶Hourglass
论文中使用的四阶Hourglass
原文使用了8个4阶Hourglass。。。
输入 64 × 64 64\times64 64×64的feature map,经过4次下采样后获得 8 × 8 8\times8 8×8的feature map,然后进行4次上采样。在整个过程中,Channel=256。

完整网络

在这里插入图片描述
Stacked Hourglass:最终一共使用了8个沙漏网络。每个沙漏网络的输入都为 64 × 64 64\times64 64×64

输入图片大小为 256 × 256 256\times256 256×256,一开始经过一次 7 × 7 7\times7 7×7 stride=2 的卷积(padding=(kernel_size-1)//2),紧接着跟随一个residual module和 max pooling将像素值从128下降到64(减少hourglass内部计算量)。其中所有的残差模块输出256个特征图。

单个 Hourglass 能够做到提取特征图在不同尺度上的信息,但是仍然不能在预测时(显式地)考虑不同关键点之间的关系。所以才需要进一步设计 Stacked Hourglass。

损失函数

每个关键点的 Ground Truth 定义为以该关键点为峰值位置的 2D 高斯函数。Loss Function 定义为 Ground Truth 与预测得到的 heatmap 之间的均方误差。

Intermediate Supervision

中继监督的思想在更早的网络里就已经被提出了。GoogLeNet V1 就在网络中部和中后部额外设置了全连接层分类支路,和最终的输出一样,这些支路对 Loss 有一定的贡献。当时的想法是,由于深层网络的多次下采样操作,一定尺度上的特征信息会丢失,所以才设计了这些中继监督位点。与 Stacked Hourglass 不同,GoogLeNet 中的这些中继监督位点的输出并没有再返回网络里。
在这里插入图片描述
x输入Hourglass后得到(B, 256,64,64)的输出值,通过用于预测的 1 × 1 1\times1 1×1卷积(蓝框部分)改变Channel,使其与GroundTruth一致,即得到预测(B, 16, 64, 64),该预测可以计算Loss。最后将预测、Hourglass的输出值,以及输入值融合\相加(通过 1 × 1 1\times1 1×1卷积保证Channel一致),送到下一个Hourglass。

直观上,将中继监督的输出重新返回网络中,起到了一种“让网络对当前的特征和预测结果进行再评估”的作用。通过重复地将多个 Hourglass 和中继监督串联,网络能够显式地学习每个预测目标之间的关系,越靠后的预测越能够结合所有关键点的位置信息,做出更准确的关键点位置预测。有了这种特性,网络就基本不会预测出一些从解剖学上不成立的人体姿势。

按元素相加虽然直观上看比较奇怪,但实际上,concatenate到一起之后,再通过一次卷积降维,那次卷积的最后其实也是按元素相加的操作。所以这里直接按元素相加,可以当作之前的卷积层已经得到了可相加的特征,这并没有什么不妥。

GroundTruth
每个关键点的 Ground Truth 定义为以该关键点为峰值位置的 2D 高斯函数。

首先介绍MPII数据集。该数据集主要用于单个人的姿态估计,但它确实为同一图像中的多个人提供关节注释。对于每个人,它给出了16个关节的坐标,比如左脚踝或右肩膀。
在这里插入图片描述
关于GroundTruth的另一件重要的事情是高斯分布。当我们生成GroundTruth的heatmap时,我们不只是为关节坐标分配1,并为所有其他像素分配0。这将使GroundTruth过于稀疏,难以了解。如果模型预测只差几个像素,也是值得鼓励的。
在这里插入图片描述
用高斯函数对关键点处理,使其中心值最大,中心周围区域值逐渐减小。左图是单个关键点的heatmap,右图是把所有16个关节放在一张heatmap中。
在这里插入图片描述

预测

与直接回归相比,使用热图的一个缺点是粒度(granularity)。例如,使用 256 × 256 256\times256 256×256输入,我们将得到一个 64 × 64 64\times64 64×64的热图来表示关键点位置。四倍的缩小比例似乎不是很糟糕。然而,我们通常首先将较大的图像(如 720 × 480 720\times480 720×480)调整 256 × 256 256\times256 256×256输入。在这种情况下,64x64的热图太粗糙了。为了缓解这个问题,研究人员提出了一个有趣的想法。我们不只是使用最大值的像素,我们还考虑了相邻的最大值像素。由于某个相邻像素也很高,因此它推断实际的关键点位置可能是朝向相邻像素的方向。听起来很熟悉,对吧?这很像梯度下降法,它也指向最优解。

消融实验

首先设计了几组网络,来讨论中间监督和stacked Hourglass
在这里插入图片描述
为了探索stacked Hourglass设计的效果,我们必须证明性能的变化是框架的设计,而不是由于更大、更深的网络。

在图9中比较了2堆叠、4堆叠和8堆叠网络的验证精度,它们有相同的参数,都包含中间预测。
在这里插入图片描述

代码

建立模型

基本层

from torch import nn
Pool = nn.MaxPool2d

def batchnorm(x):  
    return nn.BatchNorm2d(x.size()[1])(x)
class Conv(nn.Module):
    """
        卷积层(包含BN和ReLU)
        参数:inp_dim 输入Channel
                 out_dim 输出Channel    
    """
    def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True)
        self.relu = None
        self.bn = None
        if relu:
            self.relu = nn.ReLU()
        if bn:
            self.bn = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

Residual Model

class Residual(nn.Module):
    """ 
        Residual层 (实际上是BottleNeck)
        参数:inp_dim 输入Channel
              out_dim 输出Channel
    """
    def __init__(self, inp_dim, out_dim):
        super(Residual, self).__init__()
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, int(out_dim/2), 1, relu=False)
        self.bn2 = nn.BatchNorm2d(int(out_dim/2))
        self.conv2 = Conv(int(out_dim/2), int(out_dim/2), 3, relu=False)
        self.bn3 = nn.BatchNorm2d(int(out_dim/2))
        self.conv3 = Conv(int(out_dim/2), out_dim, 1, relu=False)
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        if inp_dim == out_dim:
            self.need_skip = False
        else:
            self.need_skip = True

    def forward(self, x):
        if self.need_skip:
            residual = self.skip_layer(x)
        else:
            residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        out += residual
        return out 

Hourglass Model


class Hourglass(nn.Module):
    """
        Residual层不改变输入feature的W、H
        下采样主要由 pool 完成
    """
    def __init__(self, n, f, bn=None, increase=0):
        super(Hourglass, self).__init__()
        nf = f + increase
        self.up1 = Residual(f, f)
        # Lower branch
        self.pool1 = Pool(2, 2)
        self.low1 = Residual(f, nf)
        self.n = n
        # Recursive hourglass
        if self.n > 1:
            self.low2 = Hourglass(n-1, nf, bn=bn)
        else:
            self.low2 = Residual(nf, nf)
        self.low3 = Residual(nf, f)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        up1  = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2  = self.up2(low3)
        return up1 + up2

完整网络

class UnFlatten(nn.Module):
    def forward(self, input):
        return input.view(-1, 256, 4, 4)

class Merge(nn.Module):
    def __init__(self, x_dim, y_dim):
        super(Merge, self).__init__()
        self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False)

class PoseNet(nn.Module):
    """
        'nstack': 8,
        'inp_dim': 256,
        'oup_dim': 16,
        'num_parts': 16,
        'increase': 0,
    """
    def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0, **kwargs):
        super(PoseNet, self).__init__()
        
        self.nstack = nstack
        self.pre = nn.Sequential(
            Conv(3, 64, 7, 2, bn=True, relu=True),
            Residual(64, 128),
            Pool(2, 2),
            Residual(128, 128),
            Residual(128, inp_dim)
        )
        
        self.hgs = nn.ModuleList( [
        nn.Sequential(
            Hourglass(4, inp_dim, bn, increase),
        ) for i in range(nstack)] )       
        self.features = nn.ModuleList( [
        
        nn.Sequential(
            Residual(inp_dim, inp_dim),
            Conv(inp_dim, inp_dim, 1, bn=True, relu=True)
        ) for i in range(nstack)] )       
        self.outs = nn.ModuleList( [Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] )
        self.merge_features = nn.ModuleList( [Merge(inp_dim, inp_dim) for i in range(nstack-1)] )
        self.merge_preds = nn.ModuleList( [Merge(oup_dim, inp_dim) for i in range(nstack-1)] )
        self.nstack = nstack
        self.heatmapLoss = HeatmapLoss()

    def forward(self, imgs):
        ## our posenet
        x = imgs.permute(0, 3, 1, 2) #x of size 1,3,inpdim,inpdim
        x = self.pre(x)                      # x :(B, 256, 64, 64)
        combined_hm_preds = []    # (i, B, 16, 64, 64 )
        for i in range(self.nstack):
            hg = self.hgs[i](x)                 #(B, 256, 64, 64)
            feature = self.features[i](hg) #(B, 256, 64, 64)
            preds = self.outs[i](feature)  #(B, 256, 16, 16)
            combined_hm_preds.append(preds)
            if i < self.nstack - 1:
                x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)   # 将 估计 与 特征 以及 x 融合
        return torch.stack(combined_hm_preds, 1) # 形成新的数组  (B, i,16, 64, 64 )

    def calc_loss(self, combined_hm_preds, heatmaps):
        combined_loss = []
        for i in range(self.nstack):
            combined_loss.append(self.heatmapLoss(combined_hm_preds[0][:,i], heatmaps)) 
        combined_loss = torch.stack(combined_loss, dim=1)
        return combined_loss     

损失函数

import torch
class HeatmapLoss(torch.nn.Module):
    """
    loss for detection heatmap
    """
    def __init__(self):
        super(HeatmapLoss, self).__init__()

    def forward(self, pred, gt):
        l = ((pred - gt)**2)
        l = l.mean(dim=3).mean(dim=2).mean(dim=1)
        return l ## l of dim bsize: torch.size([Bsize])
        

参考博客
https://blog.csdn.net/shenxiaolu1984/article/details/51428392
https://blog.csdn.net/u013841196/article/details/81048237

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值