【语义分割网络系列】二、U-Net

本文深入解析了U-Net网络结构,包括Encoder-Decoder设计、上采样、跳跃连接、损失函数及其在图像分割中的应用。U-Net通过结合上下文信息和位置细节,解决了小样本情况下深度网络的挑战,并使用数据增强提升模型性能。此外,介绍了网络的Pytorch实现,展示了如何构建和训练U-Net模型。
摘要由CSDN通过智能技术生成


参考资料

论文

  U-Net: Convolutional Networks for Biomedical Image Segmentation

博客

  U-Net原理分析与代码解读

  图像分割之U-Net


第1章 前言

 深度网络通常需要大量的数据进行训练,当样本量较少的情况下,深度网络可能表现没那么好。对于这个问题,本文提出了新的网络架构和图像增强策略。网络架构包括encoder和decoder(论文中称为浓缩路径(contracting path)和扩展路径(expanding path),实际上就是encoder和decoder),encoder可以有效地捕捉上下文信息,而decoder可以较好地预测位置信息

  • Encoder:使得模型理解了图像的内容,但是丢弃了图像的位置信息。
  • Decoder:使模型结合Encoder对图像内容的理解,恢复图像的位置信息。

 网络的浓缩路径,图像分辨率逐渐降低,上下文信息会逐渐增强。在扩展路径中,通过上采样的方式,让特征图的分辨率逐渐增大。同时,为了结合低层feature map的强位置信息,将浓缩路径中的相应部分结合到扩展路径中。这种架构可以较好地进行位置定位。

U-Net做的修改有

  1. 在上采样部分,Feature Map的通道数非常大,作者认为这样可以将上下文信息传递到分辨率更高的层当中。这样做的一个结果就是它基本上和浓缩路径对称了,因此看上去像一个U形的结构。
  2. 为了预测图像边界区域的像素点,采用 overlap-tile 策略补全缺失的context。
  3. 由于训练数据太少,采用大量 弹性形变 的方式增强数据。这可以让模型更好学习形变不变性。这种增强方式对于医学图像来说很重要。
  4. 在细胞分割任务中的另一个挑战是,如何将同类别的相互接触的目标分开。本文提出了使用一种 带权重的损失(weighted loss) 。在损失函数中,分割相互接触的细胞像素获得了更大的权重。

第2章 U-Net网络结构

 U-Net 跟 FCN 都是 Encoder-Decoder 结构,结构简单但很有效。Encoder 负责特征提取,你可以将自己熟悉的各种特征提取网络放在这个位置。由于在医学方面,样本收集较为困难,作者为了解决这个问题,应用了图像增强的方法,在数据集有限的情况下获得了不错的精度。

在这里插入图片描述

如上图,U-Net 网络结构是对称的,形似英文字母 U 所以被称为 U-Net。整张图都是由蓝/白色框与各种颜色的箭头组成,其中:

  • 蓝/白色框表示 feature map
  • 蓝色箭头表示 3x3 卷积,用于特征提取;
  • 灰色箭头表示 skip-connection,用于特征融合;
  • 红色箭头表示池化 pooling,用于降低维度;
  • 绿色箭头表示上采样 upsample,用于恢复维度;
  • 青色箭头表示 1x1 卷积,用于输出结果。
  • 其中灰色箭头copy and crop中的copy就是维度上的concatenate,而crop是裁剪为了让两者的长宽一致;

 可能你会问为啥是 5 层而不是 4 层或者 6 层,emmm,这应该去问作者本人,可能对于当时作者拿到的数据集来说,这个层数的表现更好,但不代表所有的数据集这个结构都适合。我们该多关注这种 Encoder-Decoder 的设计思想,具体实现则应该因数据集而异。


2.1 Encoder

 Encoder 由卷积操作和下采样操作组成,文中所用的卷积结构统一为 3x3 的卷积核,padding 为 0 ,striding 为 1。没有 padding 所以每次卷积之后 feature map 的 H 和 W 变小了,在 skip-connection 时要注意 feature map 的维度(其实也可以将 padding 设置为 1 避免维度不对应问题),pytorch 代码:

nn.Sequential(nn.Conv2d(in_channels, out_channels, 3),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True))

 上述的两次卷积之后是一个 stride 为 2 的 max pooling,输出大小变为 1 2 × ( H , W ) \frac{1}{2} \times(H, W) 21×(H,W)

在这里插入图片描述

pytorch 代码:

nn.MaxPool2d(kernel_size=2, stride=2)

 上面的步骤重复 5 次,最后一次没有 max-pooling,直接将得到的 feature map 送入 Decoder。


2.2 Decoder

 feature map 经过 Decoder 恢复原始分辨率,该过程除了卷积之外比较关键的步骤就是 upsamplingskip-connection

(1)Upsampling

 Upsampling 上采样常用的方式有两种:1.FCN 中介绍的反卷积;2. 插值

 这里介绍文中使用的插值方式。在插值实现方式中,bilinear(双线性插值)的综合表现较好也较为常见 。

双线性插值的计算过程没有需要学习的参数,实际就是套公式,这里举个例子方便大家理解(例子介绍的是参数 align_corners Fasle 的情况)。

在这里插入图片描述

pytorch 里使用 bilinear 插值:

nn.Upsample(scale_factor=2, mode='bilinear')

(2)Skip-Connection

 CNN 网络要想获得好效果,skip-connection 基本必不可少。U-Net 中这一关键步骤融合了底层信息的位置信息与深层特征的语义信息,pytorch 代码:

torch.cat([low_layer_features, deep_layer_features], dim=1)

 这里需要注意,FCN 中深层信息与浅层信息融合是通过对应像素相加的方式,而 U-Net 是通过拼接的方式

 那么这两者有什么区别呢,其实 在 ResNet 与 DenseNet 中也有一样的区别,Resnet 使用了对应值相加,DenseNet 使用了拼接。

个人理解在相加的方式下,feature map 的维度没有变化,但每个维度都包含了更多特征,对于普通的分类任务这种不需要从 feature map 复原到原始分辨率的任务来说,这是一个高效的选择;而拼接则保留了更多的维度/位置 信息,这使得后面的 layer 可以在浅层特征与深层特征自由选择,这对语义分割任务来说更有优势。


2.3 损失函数

(1)损失函数计算

 ISBI数据集的一个非常严峻的挑战是紧密相邻的物体之间的分割问题。如图3所示,(a)是输入数据,(b)是Ground Truth,©是基于Ground Truth生成的分割掩码,(d)是U-Net使用的用于分离边界的损失权值。

在这里插入图片描述

 网络输出的是pixel-wise的softmax。表达式如下:

在这里插入图片描述

 其中, x x x 为二维平面 Ω Ω 上的像素位置, a k ( x ) a_k(x) ak(x) 表示网络最后输出层中pixel x x x 对应的第 k k k 个通道的值, K K K 是类别总数。 p k ( x ) p_k(x) pk(x) 表示像素 x x x 属于 k k k 类的概率。

 损失函数使用 negative cross entropy 。cross entropy的数学表达式如下:

在这里插入图片描述

 其中 p l ( x ) p_l(x) pl(x) 表示 x x x 在真实label所在通道上的输出概率。特别注意的是cross entropy中还添加一个权重项 w ( x ) w(x) w(x)这是因为考虑到物体间的边界需要更多的关注,所对应的损失权重需要更大

(2)像素损失权重计算

 我们得到一张图片的ground truth是一个二值的mask,本文首先采用形态学方法去计算出物体的边界。然后通过以下的表达式去计算权重图。

在这里插入图片描述

 其中 w c ( x ) w_c(x) wc(x)类别权重,需要根据训练数据集中的各类别出现的频率来进行统计,类别出现的频率越高,应该给的权重越低,频率越低则给的权重越高(文章没有详细说是怎么计算的)。

d 1 ( x ) d_1(x) d1(x) 表示物体像素到最近cell的边界的距离, d 2 ( x ) d_2(x) d2(x) 表示物体像素到第二近的cell的边界的距离。在本文中,设置 w 0 = 10 , σ = 5 w_0=10,σ=5 w0=10,σ=5


2.4 数据扩充

 由于训练集只有30张训练样本,作者使用了数据扩充的方法增加了样本数量。并且作者指出任意的弹性形变对训练非常有帮助。


第3章 总结

 U-Net是比较早的使用多尺度特征进行语义分割任务的算法之一,基于 Encoder-Decoder 结构,通过拼接的方式实现特征融合,结构简明且稳定。其U形结构也启发了后面很多算法。但其也有几个缺点:

  1. 有效卷积增加了模型设计的难度和普适性;目前很多算法直接采用了same卷积,这样也可以免去Feature Map合并之前的裁边操作
  2. 其通过裁边的形式和Feature Map并不是对称的,个人感觉采用双线性插值的效果应该会更好。

第4章 Pytorch实现U-Net

参考

  U-Net网络原理分析与pytorch实现


class U_Net(nn.Module):
    def __init__(self):
        super().__init__()

        # 首先定义左半部分网络
        # left_conv_1 表示连续的两个(卷积+激活)
        # 随后进行最大池化
        self.left_conv_1 = ConvBlock(in_channels=3, middle_channels=64, out_channels=64)
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_2 = ConvBlock(in_channels=64, middle_channels=128, out_channels=128)
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_3 = ConvBlock(in_channels=128, middle_channels=256, out_channels=256)
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_4 = ConvBlock(in_channels=256, middle_channels=512, out_channels=512)
        self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.left_conv_5 = ConvBlock(in_channels=512, middle_channels=1024, out_channels=1024)

        # 定义右半部分网络
        self.deconv_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.right_conv_1 = ConvBlock(in_channels=1024, middle_channels=512, out_channels=512)

        self.deconv_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.right_conv_2 = ConvBlock(in_channels=512, middle_channels=256, out_channels=256)

        self.deconv_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, padding=1, stride=2 ,output_padding=1)
        self.right_conv_3 = ConvBlock(in_channels=256, middle_channels=128, out_channels=128)

        self.deconv_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, output_padding=1, padding=1)
        self.right_conv_4 = ConvBlock(in_channels=128, middle_channels=64, out_channels=64)
        # 最后是1x1的卷积,用于将通道数化为3
        self.right_conv_5 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1, padding=0)

    def forward(self, x):

        # 1:进行编码过程
        feature_1 = self.left_conv_1(x)
        feature_1_pool = self.pool_1(feature_1)

        feature_2 = self.left_conv_2(feature_1_pool)
        feature_2_pool = self.pool_2(feature_2)

        feature_3 = self.left_conv_3(feature_2_pool)
        feature_3_pool = self.pool_3(feature_3)

        feature_4 = self.left_conv_4(feature_3_pool)
        feature_4_pool = self.pool_4(feature_4)

        feature_5 = self.left_conv_5(feature_4_pool)

        # 2:进行解码过程
        de_feature_1 = self.deconv_1(feature_5)
        # 特征拼接
        temp = torch.cat((feature_4, de_feature_1), dim=1)
        de_feature_1_conv = self.right_conv_1(temp)

        de_feature_2 = self.deconv_2(de_feature_1_conv)
        temp = torch.cat((feature_3, de_feature_2), dim=1)
        de_feature_2_conv = self.right_conv_2(temp)

        de_feature_3 = self.deconv_3(de_feature_2_conv)

        temp = torch.cat((feature_2, de_feature_3), dim=1)
        de_feature_3_conv = self.right_conv_3(temp)

        de_feature_4 = self.deconv_4(de_feature_3_conv)
        temp = torch.cat((feature_1, de_feature_4), dim=1)
        de_feature_4_conv = self.right_conv_4(temp)

        out = self.right_conv_5(de_feature_4_conv)

        return out
测试网络输入和输出的尺寸是否一致:

if __name__ == "__main__":
    x = torch.rand(size=(8, 3, 224, 224))
    net = U_Net()
    out = net(x)
    print(out.size())
    print("ok")
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

travellerss

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值