PPM 金字塔池化模块 - PSPNet

原理浅析

  金字塔池化模块(Pyramid Pooling Module,PPM)于 2017 年提出,论文相关信息如下:

论文:《Pyramid Scene Parsing Network
作者:Hengshuang Zhao et al.(香港中文大学 & 商汤科技)
来源:CVPR 2017

  PPM 提出的目的,是为了聚合不同区域的上下文信息,以提高网络获取全局信息的能力。具体做法为:在原始特征图上使用不同尺度的池化,得到多个不同尺寸的特征图,再在通道维度上拼接这些特征图 (含原始特征图),最终输出一个糅合了多种尺度的复合特征图,从而达到兼顾全局语义信息与局部细节信息的目的。PSPNet 网络结构如下:

( a ) (a) (a) 输入图片;
( b ) (b) (b) 通过 CNN 提取的原始特征图 ( 6 × 6 6 \times 6 6×6);
( c ) (c) (c) PPM 模块:对原始特征图进行不同尺度的池化操作,得到多个不同尺寸的特征图(图中为 4 个)。对得到的特征图进行上采样操作,恢复至原始特征图大小 ( 6 × 6 6 \times 6 6×6),最后在通道维度上进行拼接,得到最终的复合特征图;

红:使用 ( 6 × 6 6 \times 6 6×6) 的池化,输出尺寸为 ( 1 × 1 1 \times 1 1×1) ,再通过双线性插值上采样至 ( 6 × 6 6 \times 6 6×6);
橙:使用 ( 3 × 3 3 \times 3 3×3) 的池化,输出尺寸为 ( 2 × 2 2 \times 2 2×2) ,再通过双线性插值上采样至 ( 6 × 6 6 \times 6 6×6);
蓝:使用 ( 2 × 2 2 \times 2 2×2) 的池化,输出尺寸为 ( 3 × 3 3 \times 3 3×3) ,再通过双线性插值上采样至 ( 6 × 6 6 \times 6 6×6);
绿:使用 ( 1 × 1 1 \times 1 1×1) 的池化,输出尺寸为 ( 6 × 6 6 \times 6 6×6) 。

( d ) (d) (d) 通过末层卷积实现场景解析,即像素级别的分类。

代码实现 - pytorch
# _*_coding:utf-8_*_
import torch
import torch.nn as nn
import torch.nn.functional as F


class PPM(nn.Module):
    def __init__(self, in_dim, out_dim, bins):
        super(PPM, self).__init__()
        self.features = []
        for bin in bins:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin),
                nn.Conv2d(in_dim, out_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

    def forward(self, x):
        x_size = x.size()
        out = [x]
        for f in self.features:
            temp = f(x)
            temp = F.interpolate(temp, x_size[2:], mode="bilinear", align_corners=True)
            out.append(temp)

        return torch.cat(out, 1)


if __name__ == "__main__":
    # inputs: (B, C, H, W)
    inputs = torch.rand((8, 3, 16, 16))
    # PPM params: (in_dim, out_dim, sizeList)
    ppm = PPM(3, 2, [1, 2, 3, 6])
    # outputs: (B=8, C=3+2*4=11, H=16, W=16)
    outputs = ppm(inputs)
    print("Outputs shape:", outputs.size())

参考

  1. PSP模块Tensorflow/Pytorch实现小结
  • 17
    点赞
  • 106
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值