U-Net: Convolutional Networks for Biomedical Image Segmentation

U-Net是图像语义分割领域的经典著作,由Olaf Ronneberger于2015年在Medical Image Computing and Computer-Assisted Intervention (MICCAI)提出。Olaf Ronneberger的工作多与生物图像有关,继2015年后又使用U-Net网络发表了几项工作。根据Olaf Ronneberger对U-Net的重视程度,U-Net的有效性可见一斑。此后图像领域一系列基于U-Net结构的网络不断被提出,deeplab最新的deeplabv3+在加上了decoder作为改进手段。基于以上,记录这项工作,理解其方法背后的思想是很有必要的,故有此篇博客。

U-Net是一项2015年的工作,此时期的神经网络结构是简单的,使用端到端的全卷积网络就能发表一篇不错的工作。U-Net虽说是2015年的工作,但是要完成一个工作一般要提前半年,理解一个工具、查阅文献以及实现idea直到收录前前后后快了差不多也要一年时间。因此,U-Net的思想可能在2014年就提出了。没啥话说,给大佬树个大拇哥👍!

在这里插入图片描述

U-Net: Convolutional Networks for Biomedical Image Segmentation

Abstract

业界一致认为成功地训练一个神经网络需要依赖大量的标注样本。在本文中,作者提出了一种神经网络,该网络对数据进行了增广以便更有效地使用现有标注数据。网络结构包括一个收缩路径(contracting path)和一个与其对称的扩张路径(expanding path),分别用于提取信息和实现精确定位。现在不难想到contracting path是应该就是encoder部分,expanding path对应的是decoder部分。作者发现,该结构可以在较少数据上进行端到端的训练,在ISBI语义分割挑战中效果超过了之前最好的方法(滑窗形式的卷积网络)。此外,U-Net的速度很快,在最新的GPU上分割一张 512 × 512 512 \times 512 512×512的图像仅需1s不到的时间。Caffe代码实现在这里

Introduction

在过去两年(2015之前两年)深度卷积网络在许多视觉任务中都取得了SOTA的效果。虽然卷积神经网络已经存在很长时间了(1989年伊始),但是其受限于数据量的大小以及网络的尺寸。首次突破是Krizhevsky提出的拥有百万参数的8层大网络,该网络是在ImageNet数据集上训练的。从这时起,越来越多、越来越深的网络开始涌现。

卷积神经网络的一个典型应用场景是图像分类,该任务中每个输入图像对应输出一个类别标签。然而,在许多视觉任务中,尤其是生物医学图像处理,输出一般包含定位,比如为每个像素分配一个类别标签。此外,在生物医学任务中,成千上万的训练数据往往是无法获得的(算是U-Net这个idea的一个Motivation)。因此,Ciresan想了一个办法,他不训练整张图像而是训练图像中的patch。首先,该工作可以实现定位,其次训练patch的数据量远大于图像的数量。该网络最终在ISBI 2012的EM分割挑战中取得大胜。

Ciresan提出的方法有两个明显的缺点:首先,该网络结构非常慢,因为需要分别训练各个patch,并且patch之间有很大的重叠,所以网络的冗余度较大。其次,网络的定位精度和语义信息之间要做一个均衡。大的patch需要使用更多的max-pooling layer,这样会丢失语义信息,降低定位精度,但是小的patch可见的语义范围又是有限的。最近的研究(2015年之前)提出将多层的特征也进行分类输出,如此同时实现精确定位和提取语义信息。

本文提出了一个更加优雅的网络结构,所谓的全卷积网络。作者修改并拓展了全卷积网络,使其可以在非常少的训练图像上生成更加精确的分割效果,见图1。全卷积网络的主要思想就是将网络层中的特征缩小后上采样再卷积输出,因此这些层提升了输出的分辨率。为了精确定位,U-Net将缩小路径上的特征与上采样后的输出结合。随后的卷积层将基于这些信息学习,以聚合更精确的输出。

U-Net的一个重要的改进在于上采样部分包含很多的特征通道,使得特征信息可以传播到拥有更高分辨率的层。因此放大路径(网络图的右半边)或多或少和缩小路径(左半边)是对称的。U-Net没有使用任何的全连接层,只使用了卷积层。分割图像里只包含像素,而完整的语义信息可以从输入图像中获得。同时U-Net可以对任意尺寸的输入图像进行语义分割。从网络结构中可以看出,网络的输出分辨率比输入分辨率小。所以预测某分辨率的分割图像需要输入比它大的输入图像,使用镜像操作来补充预测区域的周围区域,见图2。U-Net这样做每次预测输入图像的一个区域,把网络应用到大图像上时就可以分别区域预测,最后拼接到一起。这样做也使得在预测大图时不再收到GPU显存的限制。没咋看懂论文写的意思,可能就是预测patch,最后拼到一起。
U-Net网络结构
在这里插入图片描述
对于本文的任务,可获取的训练数据非常稀少,因此在训练过程中使用了大量的数据增广,具体是伸缩变换。网络可以学习使用这类变换后的不变特征,且这类变换不需要施加在标注图像中。变换在生物医学分割中是很重要的,因为变换是在组织中最常用的且真实的变换可以有效地模拟真实数据。数据增广对学习不变性的价值也在Dosovitskiy[2]提出的无监督特征学习中得到体现。

细胞分割任务中的另一个挑战就是分离互相接触的同类对象,见图3。为此,作者提出了加权损失,使接触的细胞之间的背景标签在损失函数中获得大的权重,以此来区分分割对象之间的边界。
在这里插入图片描述
U-Net被应用到了多个生物医学分割问题中,本文展示了EM stacks 中的分割结果并超越了Ciresan提出的结构。此外,作者还展示了显微镜图像的分割效果并在挑战赛中取得了大胜。

Network Architecture

网络结构如图1所示,包括了左侧的缩小路径和右侧的放大路径。缩小路径是一个典型的卷积网络,其包含重复了两个连续 3 × 3 3\times 3 3×3的卷积层(不进行padding),紧接着是一个rectified linear unit(ReLU)以及一个步长为2的 2 × 2 2\times 2 2×2的max-pooling层。每个降采样操作后的卷积层输出通道数是出入的2倍。方法路径中的每一步包含一个上采样层和一个向上卷积层(应该是转置卷积层),每个向上卷积层的输出是输入特征的一半。从图1中还可以看出,缩小路径中的特征会进行crop后与放大路径的同级特征连接在一起,再经过两个 3 × 3 3\times 3 3×3的卷积层和一个ReLU。Crop是必须的,因为每次卷积后边界像素都是丢失。在最后一层,使用 1x1 卷积将64 个特征向量映射到输出,输出的通道数与分割图像中类别数量一致。网络共有23个卷积层。需要注意,输入的图像尺寸要使得每次下采样时的像素的宽高是偶数,不然下采样时会有像素覆盖不到。

Training

U-Net在Caffe框架下使用随机梯度下降法训练,由于使用了没有padding的卷积层,输出的分割图像会小于输入图像。为了最小化间接开销和最大化地使用显存,作者采用输入大的图像尺寸而不是设置大的batch size,因此本文的batch size为1,即单张图像。基于此,作者设置了较高的动量(0.99)以使用先前的训练样本来确定当前批次的参数更新。

能量函数由最后的输出进行逐像素进行soft-max得到,损失函数为交叉熵损失。Soft-max定义为 p k ( x ) = exp ⁡ ( a k ( x ) ) / ( ∑ k ′ = 1 K exp ⁡ ( a k ′ ( x ) ) ) p_{k}(\mathbf{x})=\exp \left(a_{k}(\mathbf{x})\right) /\left(\sum_{k^{\prime}=1}^{K} \exp \left(a_{k^{\prime}}(\mathbf{x})\right)\right) pk(x)=exp(ak(x))/(k=1Kexp(ak(x))),其中 a k ( x ) a_{k}(\mathbf{x}) ak(x)表示特征在 k k k通道 x x x位置的激活值, x ∈ Ω \mathbf{x} \in \Omega xΩ Ω ⊂ Z 2 \Omega \subset \mathbb{Z}^{2} ΩZ2 K K K表示网络输出的类别数, p k ( x ) p_k(x) pk(x)为近似的最大函数。换句话说, p k ( x ) ≈ 1 p_{k}(\mathbf{x}) \approx 1 pk(x)1对应的是拥有类别 k k k最大响应的 a k ( x ) a_k(x) ak(x) p k ( x ) ≈ 0 p_{k}(\mathbf{x}) \approx 0 pk(x)0对应的是类别 k k k其他的响应值。交叉熵损失计算每个像素位置上 p ℓ ( x ) ( x ) p_{\ell(\mathbf{x})}(\mathbf{x}) p(x)(x) 1 1 1的差距 E = ∑ x ∈ Ω w ( x ) log ⁡ ( p ℓ ( x ) ( x ) ) E=\sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \log \left(p_{\ell(\mathbf{x})}(\mathbf{x})\right) E=xΩw(x)log(p(x)(x)),其中 ℓ : Ω → { 1 , … , K } \ell: \Omega \rightarrow\{1, \ldots, K\} :Ω{1,,K}是每个像素真实的label, w : Ω → R w: \Omega \rightarrow \mathbb{R} w:ΩR是上文提到的weight map。

作者先根据真实的分割图预先计算了weight map来补偿训练数据集中某特定类别的不同像素频率,因此得以强制让网络学习相接触分割对象之间的分离边界。分离边界通过形态学操作计算得到,weight map可以通过以下方式计算 w ( x ) = w c ( x ) + w 0 ⋅ exp ⁡ ( − ( d 1 ( x ) + d 2 ( x ) ) 2 2 σ 2 ) w(\mathbf{x})=w_{c}(\mathbf{x})+w_{0} \cdot \exp \left(-\frac{\left(d_{1}(\mathbf{x})+d_{2}(\mathbf{x})\right)^{2}}{2 \sigma^{2}}\right) w(x)=wc(x)+w0exp(2σ2(d1(x)+d2(x))2),其中 w c : Ω → R w_{c}: \Omega \rightarrow \mathbb{R} wc:ΩR用于平衡类别频率, d 1 : Ω → R d_{1}: \Omega \rightarrow \mathbb{R} d1:ΩR表示到最近细胞边界的距离, d 2 : Ω → R d_{2}: \Omega \rightarrow \mathbb{R} d2:ΩR表示到第二近细胞边界的距离。在实验中,作者设置 set ⁡ w 0 = 10 \operatorname{set} w_{0}=10 setw0=10 σ ≈ 5 \sigma \approx 5 σ5个像素。

在带有卷积层与不同路径的网络中,一个好的初始化权重是非常重要的。否则,网络的某些部分可能会提供过多的激活,而其他部分则永远不会做出贡献(这句是真看不明白)。理想的权重初始化方法是将权重调整为具有相似的单位方差。对于本文的网络,初始化权重可以由高斯分布得到,标准差为 2 / N \sqrt{2 / N} 2/N N N N是输入神经元的节点数。比如,对于输入层的64通道的 3 ∗ 3 3*3 33的卷积核来说 N = 9 ⋅ 64 = 576 N=9 \cdot 64=576 N=964=576

下面贴个粗糙实现(有问题请指正)

import torch.nn as nn
import torch
# from torchsummary import summary

class ConvBlock(nn.Module):
    def __init__(self, channelList):
        super(ConvBlock, self).__init__()
        self.conv2d_block = nn.Sequential(
            nn.Conv2d(channelList[0], channelList[1], 3),
            nn.ReLU(),
            nn.Conv2d(channelList[1], channelList[2], 3),
            nn.ReLU()
        )
    def forward(self, x):
        return self.conv2d_block(x)

class Unet(nn.Module):
    def __init__(self,EnChannelsList, DeChannelsList):
        super(Unet, self).__init__()
        self.conv2d_block1 = ConvBlock(EnChannelsList[0])
        self.conv2d_block2 = ConvBlock(EnChannelsList[1])
        self.conv2d_block3 = ConvBlock(EnChannelsList[2])
        self.conv2d_block4 = ConvBlock(EnChannelsList[3])
        self.conv2d_block5 = ConvBlock(EnChannelsList[4])
        self.deconv2d_block1 = ConvBlock(DeChannelsList[0])
        self.deconv2d_block2 = ConvBlock(DeChannelsList[1])
        self.deconv2d_block3 = ConvBlock(DeChannelsList[2])
        self.deconv2d_block4 = ConvBlock(DeChannelsList[3])
        self.maxpool = nn.MaxPool2d(2)
        self.deconv2d1 = nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1)
        self.deconv2d2 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1)
        self.deconv2d3 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.deconv2d4 = nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1)
        self.conv_last = nn.Sequential(
            nn.Conv2d(64, 2, 1),
            nn.Sigmoid()
        )
        
    
    def forward(self, x):
        # encoder
        conv1 = self.conv2d_block1(x)
        maxpool1 = self.maxpool(conv1)
        conv2 = self.conv2d_block2(maxpool1)
        maxpool2 = self.maxpool(conv2)
        conv3 = self.conv2d_block3(maxpool2)
        maxpool3 = self.maxpool(conv3)
        conv4 = self.conv2d_block4(maxpool3)
        maxpool4 = self.maxpool(conv4)
        conv5 = self.conv2d_block5(maxpool4)
        # decoder
        uppool1 = self.deconv2d1(conv5)
        cat1 = torch.cat([uppool1,conv4[:,:,4:conv4.shape[2]-4,4:conv4.shape[3]-4]], dim=1)
        deconv1 = self.deconv2d_block1(cat1)
        uppool2 = self.deconv2d2(deconv1)
        cat2 = torch.cat([uppool2,conv3[:,:,16:conv3.shape[2]-16,16:conv3.shape[3]-16]], dim=1)
        deconv2 = self.deconv2d_block2(cat2)
        uppool3 = self.deconv2d3(deconv2)
        cat3 = torch.cat([uppool3,conv2[:,:,40:conv2.shape[2]-40,40:conv2.shape[3]-40]], dim=1)
        deconv3 = self.deconv2d_block3(cat3)
        uppool4 = self.deconv2d4(deconv3)
        cat4 = torch.cat([uppool4,conv1[:,:,88:conv1.shape[2]-88,88:conv1.shape[3]-88]], dim=1)
        deconv4 = self.deconv2d_block4(cat4)
        out = self.conv_last(deconv4)
        return out

EnChannelsList = [
    [1, 64, 64],
    [64, 128, 128],
    [128, 256, 256],
    [256, 512, 512],
    [512, 1024, 1024]
]

DeChannelsList = [
    [1024, 512, 512],
    [512, 256, 256],
    [256, 128, 128],
    [192, 64, 64],
]

model = Unet(EnChannelsList, DeChannelsList).cuda()
# summary(model, (1, 572, 572))
  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值