Pyramid pooling module

Pyramid pooling 方法出自 2017CVPR,原文地址https://arxiv.org/pdf/1612.01105.pdf

该文的一大贡献就是Pyramid pooling module(简称PPM)

1. PPM有什么用

       一般可以粗略地认为感受野就是使用上下文信息的大小。在很多网络中,我们都很重视全局信息的获取。在FCN中,就是没有充分的场景的上下文信息,导致在一些不同尺度的物体分割上处理不好。

 没有充分利用好场景的上下文信息就会有这些问题(分别如上图所示):

(1)Mismatched Relationship

(2)Confusion Categories

(3)Inconspicuous Classes

总之,PPM就是一种相对较好的充分利用全局信息的方式。这种保留全局信息的思路其实与ASPP(Atrous Spatial Pyramid Pooling) 很相似。从直觉上来看,这种多尺度的pooling确实是可以在不同的尺度下来保留全局信息,比起普通的单一pooling更能保留全局上下文信息。

2 PPM的结构

下面描述下PPM的过程。

原文中采用4种不同金字塔尺度,金字塔池化模块的层数和每层的size是可以修改的。论文中金字塔池化模块是4层,每层的size分别是1×1,2×2,3×3,6×6。

首先,对特征图分别池化到目标size,然后对池化后的结果进行1×1卷积将channel减少到原来的1/N,这里N就为4。接着,对上一步的每一个特征图利用双线性插值上采样得到原特征图相同的size,然后将原特征图和上采样得到的特征图按channel维进行concatenate。得到的channel是原特征图的channel的两倍,最后再用1×1卷积将channel缩小到原来的channel。最终的特征图和原来的特征图size和channel是一样的。

3. PPM代码

class PyramidPooling(nn.Module):
    """Pyramid pooling module"""

    def __init__(self, in_channels, out_channels, **kwargs):
        super(PyramidPooling, self).__init__()
        inter_channels = int(in_channels / 4)   #这里N=4与原文一致
        self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)  # 四个1x1卷积用来减小channel为原来的1/N
        self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.out = _ConvBNReLU(in_channels * 2, out_channels, 1)  #最后的1x1卷积缩小为原来的channel

    def pool(self, x, size):
        avgpool = nn.AdaptiveAvgPool2d(size)   # 自适应的平均池化,目标size分别为1x1,2x2,3x3,6x6
        return avgpool(x)

    def upsample(self, x, size):    #上采样使用双线性插值
        return F.interpolate(x, size, mode='bilinear', align_corners=True)

    def forward(self, x):
        size = x.size()[2:]
        feat1 = self.upsample(self.conv1(self.pool(x, 1)), size)
        feat2 = self.upsample(self.conv2(self.pool(x, 2)), size)
        feat3 = self.upsample(self.conv3(self.pool(x, 3)), size)
        feat4 = self.upsample(self.conv4(self.pool(x, 6)), size)
        x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1)   #concat 四个池化的结果
        x = self.out(x)
        return x

 

  • 18
    点赞
  • 81
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值