Image Segmentation 2: Implementing DeepLabv3+

Contents

I. Model Structure

1. Backbone

2. Encoder

3. Decoder

II. Model Implementation


I. Model Structure

Assume an input of shape (4,3,128,128), i.e. a batch of 4 three-channel images, each 128×128 pixels.

1.Backbone

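This model uses ResNet101 as the backbone. The backbone output has shape (4,2048,8,8), i.e. an output stride of 16. In addition, the output of the first block of ResNet101 is returned as a low-level feature with shape (4,256,32,32) (output stride 4). This is where DeepLabv3+ differs from DeepLabv3: the low-level feature is fed into the decoder for feature fusion.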
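The shapes above follow from dividing the input size by the output stride of each feature. The sketch below assumes, as those shapes imply, a stride of 4 for the low-level feature and 16 for the final feature. A plain ResNet101 reaches stride 32, so the `resnet.py` used here presumably replaces the stride of its last stage with dilation. The helper name `feature_shape` is hypothetical.

```python
def feature_shape(input_shape, channels, output_stride):
    """Return the (N, C, H, W) shape of a feature map taken at a given output stride.

    Assumes H and W are divisible by output_stride, as they are for 128x128 inputs.
    """
    n, _, h, w = input_shape
    return (n, channels, h // output_stride, w // output_stride)


x_shape = (4, 3, 128, 128)
low_level = feature_shape(x_shape, 256, 4)     # output of the first block
high_level = feature_shape(x_shape, 2048, 16)  # final backbone output

print(low_level)   # (4, 256, 32, 32)
print(high_level)  # (4, 2048, 8, 8)
```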

2.Encoder

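The encoder is the ASPP module. It contains one 1×1 convolution, three 3×3 atrous (dilated) convolutions with dilation rates 6, 12 and 18, and a global average pooling branch. The five results are concatenated and passed through another 1×1 convolution to give the encoder output, with shape (4,256,8,8). This output goes into the decoder, where it is fused with the low-level feature coming from the backbone.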
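All five ASPP branches must produce the same spatial size to be concatenated. The sketch below applies the standard convolution output-size formula (the one documented for `nn.Conv2d`) to the branch settings used in the code. It shows that `padding=dilation` keeps a 3×3 dilated convolution at 8×8, and where the 1280 input channels of the final 1×1 convolution come from. The function name `conv_out_size` is hypothetical.

```python
def conv_out_size(size, kernel_size, padding, dilation, stride=1):
    """Output length of a convolution along one spatial axis (same formula as nn.Conv2d)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


H = 8  # spatial size of the backbone output for a 128x128 input
# (kernel_size, padding, dilation) of the four convolution branches in the ASPP code
branches = [(1, 0, 1), (3, 6, 6), (3, 12, 12), (3, 18, 18)]
sizes = [conv_out_size(H, k, p, d) for k, p, d in branches]
print(sizes)  # [8, 8, 8, 8]

# Effective extent of a dilated kernel: d * (k - 1) + 1
print([d * (k - 1) + 1 for k, _, d in branches])  # [1, 13, 25, 37]

# Four conv branches + one pooling branch, 256 channels each, are concatenated
print(5 * 256)  # 1280, the input channels of the final 1x1 conv
```

On an 8×8 map with dilation 18, every off-center tap of the 3×3 kernel lands in the zero padding, so that branch behaves almost like a 1×1 convolution. The DeepLabv3 paper names this loss of valid weights at large rates as the reason for adding the image-level pooling branch.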

3.Decoder

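The decoder first upsamples the encoder feature map by 4× (from 8×8 to 32×32) so that it matches the size of the low-level feature. The low-level feature's channels are first reduced from 256 to 48 by a 1×1 convolution. The two are then concatenated (256 + 48 = 304 channels) and passed through further convolutions. A final 4× upsampling produces an output at the input resolution.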
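The shape flow through the decoder can be traced with plain tuples. This sketch hard-codes the channel counts from the code (256 from ASPP, 48 after the low-level 1×1 convolution). It assumes `num_classes=21` purely as an illustrative value, and the function name `decoder_shapes` is hypothetical.

```python
def decoder_shapes(n=4, h=128, w=128, num_classes=21):
    """Trace tensor shapes through the decoder for an (n, 3, h, w) input.

    Channel counts follow the model code: ASPP gives 256 channels at stride 16,
    and the low-level feature (256 channels, stride 4) is reduced to 48 channels.
    num_classes=21 is only an illustrative value.
    """
    aspp_out = (n, 256, h // 16, w // 16)
    low_level = (n, 48, h // 4, w // 4)        # after the 1x1 conv
    upsampled = (n, 256, h // 4, w // 4)       # ASPP output upsampled 4x
    concat = (n, upsampled[1] + low_level[1], h // 4, w // 4)
    logits = (n, num_classes, h // 4, w // 4)  # after last_conv
    output = (n, num_classes, h, w)            # final 4x upsampling
    return [aspp_out, low_level, upsampled, concat, logits, output]


for shape in decoder_shapes():
    print(shape)
# (4, 256, 8, 8)
# (4, 48, 32, 32)
# (4, 256, 32, 32)
# (4, 304, 32, 32)
# (4, 21, 32, 32)
# (4, 21, 128, 128)
```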

II. Model Implementation

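F.interpolate is used here for upsampling. With the same size, mode and align_corners arguments it gives the same result as nn.Upsample, which is a module wrapper around it.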
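Every upsampling call in the code uses `mode='bilinear', align_corners=True`. The pure-Python sketch below illustrates the 1-D sampling rule behind that option: output index `i` maps to source coordinate `i * (in - 1) / (out - 1)`, so the corner samples coincide. Bilinear mode applies the same rule along height and width. This illustrates the rule only and is not PyTorch's implementation. The function name is hypothetical.

```python
def linear_upsample_align_corners(values, out_size):
    """1-D linear interpolation with align_corners=True semantics."""
    in_size = len(values)
    if out_size == 1 or in_size == 1:
        # Degenerate cases: a single source value is simply replicated
        return [float(values[0])] * out_size
    scale = (in_size - 1) / (out_size - 1)
    out = []
    for i in range(out_size):
        src = i * scale                      # source coordinate of output index i
        left = min(int(src), in_size - 2)    # index of the left neighbour
        frac = src - left                    # weight of the right neighbour
        out.append(values[left] * (1 - frac) + values[left + 1] * frac)
    return out


print(linear_upsample_align_corners([0, 10], 3))      # [0.0, 5.0, 10.0]
print(linear_upsample_align_corners([0, 10, 20], 5))  # [0.0, 5.0, 10.0, 15.0, 20.0]
print(linear_upsample_align_corners([7], 4))          # [7.0, 7.0, 7.0, 7.0]
```

The last call mirrors the ASPP pooling branch: upsampling a 1×1 map just copies its value to every position, so the pooled global context is broadcast over the 8×8 grid before concatenation.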

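import torch
import torch.nn as nn
import torch.nn.functional as F

# ResNet101 is used as the backbone. It comes from a local resnet.py that is not
# shown in this post and must return a tuple (x, low_level_feat), with shapes
# (N, 2048, H/16, W/16) and (N, 256, H/4, W/4).
from resnet import ResNet101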

# ASPP module
class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation):
        '''One ASPP branch: a (dilated) convolution followed by BatchNorm and ReLU'''
        super(_ASPPModule, self).__init__()
        # bias is redundant because BatchNorm follows
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                     padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)
        return self.relu(x)

class ASPP(nn.Module):
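    '''Takes the ResNet101 output (2048 channels), applies four convolution
    branches plus one image-pooling branch in parallel, concatenates them,
    and fuses the result with a final 1x1 convolution.'''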
    def __init__(self):
        super(ASPP, self).__init__()
        inplanes = 2048
        dilations = [1, 6, 12, 18]
        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0])
        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1])
        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2])
        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3])
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)


    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        # upsample the 1x1 pooled map to the spatial size of the other branches
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x1,x2,x3,x4,x5], dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return self.dropout(x)

class Decoder(nn.Module):
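    '''Decoder: the ASPP output is upsampled to the size of the low-level ResNet
    feature, concatenated with it, and refined by further convolutions
    (the final 4x upsampling is done in DeepLab3p).'''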
    def __init__(self, num_classes):
        super(Decoder, self).__init__()
        low_level_inplanes = 256
        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(48)
        self.relu = nn.ReLU()
        self.last_conv = nn.Sequential(
            nn.Conv2d(304, 256, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, stride=1)
        )

    def forward(self,x,low_level_feat):
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)

        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat([low_level_feat, x], dim=1)
        x = self.last_conv(x)
        return x

class DeepLab3p(nn.Module):
    def __init__(self, num_classes):
        super(DeepLab3p, self).__init__()
        self.backbone = ResNet101()
        self.aspp = ASPP()
        self.decoder = Decoder(num_classes=num_classes)

    def forward(self, x_in):
        x,low_level_feat = self.backbone(x_in)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        x = F.interpolate(x, x_in.size()[2:], mode='bilinear', align_corners=True)
        return x