基于MobileNet的UNet语义分割模型

What is UNet?

UNet是一个著名的语义分割模型,UNet网络在被提出后,就大范围的用于医学图像、以及自动驾驶场景的分割,甚至是很多全球的AI算法大赛如kaggle等等。而语义分割则是一种经典的深度学习计算机视觉任务,即将不同的类别用不同的RGB展示出来。如图,不同的类别,如人、火车、小轿车,路标等都被不同的颜色体现了出来。

在这里插入图片描述

Why MobileNet?

MobileNet是由谷歌开源的一款轻量级的神经网络backbone。在大多实践和工程项目中,要求推理inference的实时性,即对深度学习模型的特征提取网络的大小有一定的要求,我们今天的主角MobileNet就是性价比极高的一个轻量级网络。而UNet的backbone,即特征提取网络为一个参数量极大的VGG16模型,可想而知很多嵌入式设备是带不动的,更不能得到实时的分割效果。因此,本人想通过使用MobileNet替换VGG16的方式来轻量化我们的UNet模型,使得参数量减少,来达到加速推理的效果。本文中,本人基于pytorch深度学习框架成功修改了网络的backbone,并进行模型融合,提高了模型特征提取的准确性。

MobileNet网络结构介绍

MobileNet使用的核心思想便为depthwise separable convolution(深度可分离卷积)

假设有一个3×3大小的卷积层,其输入通道为16、输出通道为32。具体为,32个3×3大小的卷积核会遍历16个通道中的每个数据,最后可得到所需的32个输出通道,所需参数为16×32×3×3=4608个。

应用深度可分离卷积,用16个3×3大小的卷积核分别遍历16通道的数据,得到了16个特征图谱。在融合操作之前,接着用32个1×1大小的卷积核遍历这16个特征图谱,所需参数为16×3×3+16×32×1×1=656个。
可以看出来depthwise separable convolution可以减少模型的参数。

如下这张图就是深度可分离卷积的结构:
在这里插入图片描述
以下就是MobileNetV1的网络结构,其中第一层是一个普通的卷积块,由于步长stride为2,因此会对图片的长宽进行一次压缩。之后我们可以看到他会经历一个convdw(深度可分离卷积块),以及一次普通的1x1卷积,其中深度可分离卷积用来进行特征提取,1x1卷积块用来调整通道数,当然这里没有加上标准化BN以及激活函数RELU6,这些在后面的代码里会有体现。通过不断的convdw和1x1conv的叠加,最后通过平均池化和全连接通过softmax函数输出结果。这就是整个MobileNet的网络结构,接下来我们来通过代码,基于pytorch搭建一下网络结构。

pytorch搭建MobileNet网络

import time
import torch
import torch.nn as nn
from torchsummary import summary
import torch.nn.functional as F
import torchvision.models as models
import torchvision.models._utils as _utils
from torch.autograd import Variable


# conv_bn为网络的第一个卷积块,步长为2
def conv_bn(inp, oup, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


# conv_dw为深度可分离卷积
def conv_dw(inp, oup, stride=1):
    return nn.Sequential(
        # 3x3卷积提取特征,步长为2
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU6(inplace=True),

        # 1x1卷积,步长为1
        nn.Conv2d(inp, oup, 1, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    )


class MobileNet(nn.Module):
    def __init__(self, n_channels):
        super(MobileNet, self).__init__()
        self.layer1 = nn.Sequential(
            # 第一个卷积块,步长为2,压缩一次
            conv_bn(n_channels, 32, 1),  # 416,416,3 -> 208,208,32

            # 第一个深度可分离卷积,步长为1
            conv_dw(32, 64, 1),  # 208,208,32 -> 208,208,64

            # 两个深度可分离卷积块
            conv_dw(64, 128, 2),  # 208,208,64 -> 104,104,128
            conv_dw(128, 128, 1),

            # 104,104,128 -> 52,52,256
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
        )
        # 52,52,256 -> 26,26,512
        self.layer2 = nn.Sequential(
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
        )
        # 26,26,512 -> 13,13,1024
        self.layer3 = nn.Sequential(
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
        )
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.avg(x)

        x = x.view(-1, 1024)
        x = self.fc(x)
        return x

如何加入到UNet网络中?

我们刚刚通过pytorch搭建了MobileNet的网络结构。接下来我们需要将MobileNet加入到UNet中,替换之前的backbone。首先我们先看一下UNet的网络结构:
在这里插入图片描述
我们通过找到UNet中压缩次数与MobileNet中压缩次数相同的feature map(特征层),将二者对应替换,并与后面上采样得到的特征层进行堆叠(torch.cat),则可以实现整个模型的成功替换,实现代码如下:

from mobilenet.mobile import MobileNet
import torch.nn as nn
from collections import OrderedDict
import torch
import torchsummary as summary

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


def conv2d(filter_in, filter_out, kernel_size, groups=1, stride=1):
    pad = (kernel_size - 1) // 2 if kernel_size else 0
    return nn.Sequential(OrderedDict([
        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, groups=groups, bias=False)),
        ("bn", nn.BatchNorm2d(filter_out)),
        ("relu", nn.ReLU6(inplace=True)),
    ]))


class mobilenet(nn.Module):
    def __init__(self, n_channels):
        super(mobilenet, self).__init__()
        self.model = MobileNet(n_channels)

    def forward(self, x):
        out3 = self.model.layer1(x)
        out4 = self.model.layer2(out3)
        out5 = self.model.layer3(out4)

        return out3, out4, out5


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, num_classes):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.num_classes = num_classes

        # ---------------------------------------------------#
        #   64,64,256;32,32,512;16,16,1024
        # ---------------------------------------------------#
        self.backbone = mobilenet(n_channels)

        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = DoubleConv(1024, 512)

        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2 = DoubleConv(1024, 256)

        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3 = DoubleConv(512, 128)

        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        #nn.Upsample(scale_factor=2, mode='bilinear')
        self.conv4 = DoubleConv(128, 64)

        self.oup = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        #  backbone
        x2, x1, x0 = self.backbone(x)
        # print(f"x2.shape: {x2.shape}, x1: {x1.shape}, x0: {x0.shape} ")

        P5 = self.up1(x0)
        P5 = self.conv1(P5)           # P5: 26x26x512
        # print(P5.shape)
        P4 = x1                       # P4: 26x26x512
        P4 = torch.cat([P4, P5], axis=1)   # P4(堆叠后): 26x26x1024
        # print(f"cat 后是: {P4.shape}")

        P4 = self.up2(P4)             # 52x52x1024
        P4 = self.conv2(P4)           # 52x52x256
        P3 = x2                       # x2 = 52x52x256
        P3 = torch.cat([P4, P3], axis=1)  # 52x52x512

        P3 = self.up3(P3)
        P3 = self.conv3(P3)

        P3 = self.up4(P3)
        P3 = self.conv4(P3)

        out = self.oup(P3)
        # print(f"out.shape is {out.shape}")

        return out

可以看到在UNet的类中,我们通过输出MobileNet即backbone的三个特征层的结果,再进行上采样,即可以在后面的前向传播部分进行堆叠。

完整项目工程

参见:
https://github.com/YZY-stack/UNet-MobileNet-Pytorch
欢迎大家star或fork我的项目,接下来会继续完善和改进我的代码,也欢迎大家提问、批评指正。

  • 15
    点赞
  • 141
    收藏
    觉得还不错? 一键收藏
  • 19
    评论
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值