图像语义分割 -- U-Net

一:FCN回顾

上一博文我们学习了FCN,有不同的特征融合版本。
至于为什么要进行特征能融合呢?由于池化操作的存在,浅层卷积视野小,具体一些,细节更加详细,越深层的视野大,图像越小,越粗粒度,细节也是越来越模糊,所以,下采样的好处是,带来了感受域的提升,同时也减少计算量,但是却忽略了很多细节,让图像变得平湖模糊,因此,作者将浅层的细节特征也进行了特征融合。

较浅的卷积层(靠前的)的感受域比较小,学习感知细节部分的能力强,较深的隐藏层 (靠后的),感受域相对较大,适合学习较为整体的、相对更宏观一些的特征。
所以在较深的卷积层上进行反卷积还原,自然会丢失很多细节特征。
于是我们会在反卷积步骤时,考虑采用一部分较浅层的反卷积信息辅助叠加,更好的优化分割结果的精度:

至于效果具体是如何呢?
作者在原文种给出3种网络结果对比,明显可以看出效果:FCN-32s < FCN-16s < FCN-8s,即使用多层feature融合有利于提高分割准确性。
在这里插入图片描述

二:U-Net
Unet 基于 Encoder-Decoder 结构,通过拼接的方式实现特征融合,结构简明且稳定,如果你有语义分割的问题,尤其在样本数据量不大的情况下,表现还是可以的。其图示如下:

在这里插入图片描述

如上图,Unet 网络结构是对称的,形似英文字母 U 所以被称为 Unet。整张图都是由蓝/白色框与各种颜色的箭头组成,其中,蓝/白色框表示 feature map;蓝色箭头表示 3x3 卷积,用于特征提取;灰色箭头表示 skip-connection,用于特征融合;红色箭头表示池化 pooling,用于降低维度;绿色箭头表示上采样 upsample,用于恢复维度;青色箭头表示 1x1 卷积,用于输出结果。

Encoder 由卷积操作和下采样操作组成,文中所用的卷积结构统一为 3x3 的卷积核,padding 为 0 ,striding 为 1。pytorch 代码:

nn.Sequential(nn.Conv2d(in_channels, out_channels, 3),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True))

另外,Encoder中的下采样采用的是maxpooling。pytorch 代码:

nn.MaxPool2d(kernel_size=2, stride=2)

Decoder中feature map 经过 Decoder 恢复原始分辨率,该过程除了卷积比较关键的步骤就是 upsampling 与 skip-connection。

Upsampling 上采样常用的方式有两种:1.FCN 中介绍的反卷积;2. 插值。其中在插值方法中,bilinear 双线性插值的综合表现较好也较为常见。pytorch 代码:

nn.Upsample(scale_factor=2, mode='bilinear')

可用以下例子看看bilinear插值的效果。

import torch
from torch import nn

x = torch.rand(2, 3, 3, 2)
model = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)  # [2, 3, 6, 4]
model = nn.Upsample(scale_factor=3, mode='bilinear', align_corners=True)  # [2, 3, 9, 6]
y = model(x)
print(y.shape)

FNN 网络要想获得好效果,skip-connection 基本必不可少。Unet 的Decoder中这一关键步骤融合了底层信息的位置信息与深层特征的语义信息,pytorch 代码:

torch.cat([low_layer_features, deep_layer_features], dim=1)

这里需要注意的是,FCN 中深层信息与浅层信息融合是通过对应像素相加的方式,而 Unet 是通过拼接的方式。测试代码如下:

import torch
from torch import nn

low_layer_features = torch.rand(2, 3, 3, 2)
deep_layer_features = torch.rand(2, 3, 3, 2)
y = torch.cat([low_layer_features, deep_layer_features], dim=1)  # [2, 6, 3, 2]
print(y.shape)

三:U-Net具体代码实现
好了,U-Net的结构也是分析完了,关键的步骤操作和试验也差不多了,现在我们来搭建下U-Net网络吧。完整代码如下:

from torch import nn
import torch


class UNet(nn.Module):
    def __init__(self, in_channels=1, num_classes=2):  # num_classes,此处为 二分类值为2
        super(UNet, self).__init__()
        # == Encoder ==
        # 1. extract feayures, conv1
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.subpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 2. extract feayures, conv2
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.subpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 3. extract feayures, conv3
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.subpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 4. extract feayures, conv4
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        self.subpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 5. extract feayures, conv5
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 1024, 3),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),

            nn.Conv2d(1024, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # == Decoder ==
        self.uppool1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv6 = nn.Sequential(
            nn.Conv2d(1024, 512, 3),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            nn.Conv2d(512, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.uppool2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv7 = nn.Sequential(
            nn.Conv2d(512, 256, 3),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.uppool3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv8 = nn.Sequential(
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.uppool4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv9 = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, num_classes, 1),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # === encoder
        conv1 = self.conv1(x)
        conv1_sub = self.subpool1(conv1)

        conv2 = self.conv2(conv1_sub)
        conv2_sub = self.subpool2(conv2)

        conv3 = self.conv3(conv2_sub)
        conv3_sub = self.subpool3(conv3)

        conv4 = self.conv4(conv3_sub)
        conv4_sub = self.subpool4(conv4)

        conv5 = self.conv5(conv4_sub)  # U型的最低端,它既是是encoder输出,也是decoder的输入。

        # === deoder
        conv1_up = self.uppool1(conv5)
        conv6 = self.conv6(torch.cat([conv4, conv1_up], dim=1))

        conv2_up = self.uppool2(conv6)
        conv7 = self.conv7(torch.cat([conv3, conv2_up], dim=1))

        conv3_up = self.uppool3(conv7)
        conv8 = self.conv8(torch.cat([conv2, conv3_up], dim=1))

        conv4_up = self.uppool4(conv8)
        conv9 = self.conv9(torch.cat([conv1, conv4_up], dim=1))

        return conv9


if __name__ == '__main__':
    # model = VGGTest()
    x = torch.rand(64, 1, 572, 572)
    print(x.shape)

    model = UNet(in_channels=x.shape[1])
    # print(model)
    y = model(x)
    print(y.shape)

四:和FCN的区别对比
U-Net采用了与FCN完全不同的特征融合方式
与FCN逐点相加不同,U-Net采用将特征在channel维度拼接在一起,形成更“厚”的特征。所以:
语义分割网络,在浅层和深层特征融合时也有2种办法:

  1. FCN式的浅层特征和深层特征逐点相加。
  2. U-Net式的channel维度拼接融合。
    相比其他大型网络,FCN/U-Net还是蛮简单的,就不多废话了。
    总结一下,CNN图像语义分割也就基本上是这个套路:
  3. 下采样+上采样:Convlution + Deconvlution/Resize
  4. 多层次特征融合:特征逐点相加/特征channel维度拼接
  5. 获得像素级别的segement map:对每一个像素点进行判断类别
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值