【语义分割】5、PSPNet:Pyramid Scene Parsing Network

本文介绍了PyramidPoolingModule(PPM),它通过聚合不同区域的上下文信息来增强全局理解能力。PPM利用不同尺度的全局平均池化,结合1x1卷积和上采样,将多个尺度的特征融合,用于场景解析任务。PSPHead是一种使用PPM的头部结构,其在PSPNet中被提出,通过多个PPM模块和瓶颈层处理特征,提高了分割预测的准确性。
摘要由CSDN通过智能技术生成


在这里插入图片描述

一、主要思想

提出了pyramid pooling module (PPM) 模块,聚合不同区域的上下文信息,从而提高获取全局信息的能力。

现有的深度网络方法中,某一个操作的感受野直接决定了这个操作可以获得多少上下文信息,所以提升感受野可以为网络引入更多的上下文信息。

二、方法

在这里插入图片描述
Step1:使用global averag pooling得到不同尺度的特征,PPM模块融合了4个不同尺度的特征:

  • 红色是最粗糙尺度,使用一个global average pooling 实现
  • 其他的都是将特征图切分为不同数量的块,在每个块内使用global average pooling (文中四个尺度分别是 1x1, 2x2, 3x3, 6x6)

Step2:global average pooling 之后,每层都接一个1x1的卷积来降低通道维度。

Step3:上采样到和原图相同的尺寸,然后和进入PPM头之前的feature map 进行concat 来预测结果。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead
from .Attention_layer import HardClassAttention as HCA

class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg)))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        # inputs [4, 512, 64, 128]
        x = self._transform_inputs(inputs) #[4, 2048, 64, 128]
        psp_outs = [x]  # list, len=1, psp_outs[0].shape = [4, 2048, 64, 128]
        # self.psp_models(x), list, len=4
        psp_outs.extend(self.psp_modules(x)) # len(psp_outs) = 5, psp_out[1-4].shape = [4, 512, 64, 128]
        psp_outs = torch.cat(psp_outs, dim=1) # [4, 4096, 64, 128]
        output = self.bottleneck(psp_outs)    # [4, 512, 64, 128]
        output = self.cls_seg(output)         # [4, 19, 64, 128]
        # import pdb; pdb.set_trace()
        return output
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

呆呆的猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值