DeepLabV3+:ASPP加强特征提取网络的搭建

目录

ASPP结构介绍

ASPP在代码中的构建

参考资料


ASPP结构介绍

ASPP:Atrous Spatial Pyramid Pooling,空洞空间卷积池化金字塔。
简单理解就是个至尊版池化层,其目的与普通的池化层一致,尽可能地去提取特征。

利用主干特征提取网络,会得到一个浅层特征和一个深层特征,这一篇主要以如何对较深层特征进行加强特征提取,也就是在Encoder中所看到的部分。

它就叫做ASPP,主要有5个部分:

  • 1x1卷积
  • 膨胀率为6的3x3卷积
  • 膨胀率为12的3x3卷积
  • 膨胀率为18的3x3卷积
  • 对输入进去的特征层进行池化

接着会对这五个部分进行一个堆叠,再利用一个1x1卷积对通道数进行调整,获得上图中绿色的特征。

ASPP在代码中的构建

import torch
import torch.nn as nn
import torch.nn.functional as F

class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=(1,1), stride=(1,1), padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, kernel_size=(1,1), stride=(1,1), padding=0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out ,kernel_size=(1,1), stride=(1,1), padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        [b, c, row, col] = x.size()
  
        # 五个分支
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
   
        # 第五个分支,进行全局平均池化+卷积
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
  
        # 五个分支的内容堆叠起来,然后1x1卷积整合特征。
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result


if __name__ == "__main__":
    model = ASPP(dim_in=320, dim_out=256, rate=16//16)
    print(model)

那么从这里来看的话,也是相当清晰的,branch*(1、2、3、4、5)分别代表了ASPP五个部分在def __init__()可以体现,对于每一个都是卷积、标准化、激活函数。

第五个部分可以看到def forward中,首先呢,是要进行一个全局平均池化,再用1x1卷积通道数的整合,标准化、激活函数,接着采用上采样的方法,把它的大小调整成和我们上面获得的分支一样大小的特征层,这样我们才可以将五个部分进行一个堆叠,使用的是torch.cat()函数实现,最后,利用1x1卷积,对输入进来的特征层进行一个通道数的调整,获得想上图中绿色的部分,接着就会将这个具有较高语义信息的有效特征层就会传入到Decoder当中。

参考资料

(6条消息) Pytorch-torchvision源码解读:ASPP_xiongxyowo的博客-CSDN博客_aspp代码

DeepLabV3-/deeplabv3+.pdf at main · Auorui/DeepLabV3- (github.com)

  • 13
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

夏天是冰红茶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值