【论文笔记】医学图像分割 U-Net++:A Nested U-Net Architecture

1 综述

今天分享一篇2018年的论文《UNet++: A Nested U-Net Architecture for Medical Image Segmentation》,已经有很多博客将其解读的很详细了,这里不再重复。这里主要提一下论文亮点:

简而言之,文章主要对 U-Net 中的 plain skip cennections 进行修改,作者认为Encoder 和 Decoder 的不同语义之间直接连接效果并不好,提出了嵌套的和稠密的跳跃连接来减小不同特征图之间的语义差距,达到改善分割效果的目的,并用了深监督进行训练,训练后可进行模型剪枝。

论文原文:The main idea behind UNet++ is to bridge the semantic gap between the feature maps of the encoder and decoder prior to fusion。

U-Net++作者周纵苇解读:U-net++解读

论文地址:
《UNet++: A Nested U-Net Architecture for Medical Image Segmentation》
《UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation》

代码地址:UNetPlusPlus

2 网络结构

下图为论文中给出的U-Net++网络结构,清晰明了:

黑色部分:是原始 U-Net 网络结构;

绿色和蓝色部分:是稠密连接( nested and dense skip connections ),改善原始 U-Net 结构中的 plain skip connections;

红色部分:因添加稠密连接后,在计算 loss function 时,梯度无法经过绿色和蓝色 bolck 区域,所以添加了红色部分进行深监督,实现训练;同时也便于后期进行模型剪枝;

不同深度网络:U-Net++因自身结构的原因,具有不同的网络深度,如L1、L2、L3 和 L4;

在这里插入图片描述

2.1 skip connection

不同与U-Net 网络中直接进行skip connection,U-Net++是嵌套的和稠密的跳跃连接来实现的,作者认为当解码和编码的特征图的语义相似进行融合时,它学习的效果会更好

在 UNet 中高分辨率特征图快速直接地从编码到解码,结果是不相似语义的特征图之间进行融合。与 UNet 的朴素跳跃连接不同,U-Net++ 将高分辨率特征图从 Encoder 网络逐渐地和 Decoder 网络中相应语义的特征图优先进行融合,这个网络它可以更高效地捕获前景对象的深层 ( fine-grained ) 细节。

2.2 deep supervision

监督每个分支的U-Net的输出,这样可以解决中间部分无法训练的问题,具体如下:

(1)在图中 X(0,1)、X(0,2)、X(0,3)、X(0,4) 后面加一个1x1的卷积核,将 feature map 的channel 数量变换到与 output_channel 数量一致,以 Dice + Cross Entropy 作为损失函数,来进行训练;

(2)实际训练中,对于不同的loss分支,作者的给出的权重为1:1:1:1;

(3)两种模式

精确模式:将输出的所有分割分支进行平均,得到最终分割结果;

快速模型:将得到的4个分割图,只选择其中一个分支,这个选择决定了模型修剪的程度和速度增益;

2.3 模型剪枝

作者是在测试阶段,在测试集上进行剪枝的;

作者解释在测试阶段,由于输入的图像只会前向传播,扔掉这部分对前面的输出完全没有的;而在训练阶段,因为既有前向,又有后向传播,被剪掉的部分是会帮助其他部分做权重更新的。因此测试时,剪掉部分对剩余结构不做影响,训练时,剪掉的分对剩余部分有影响;

论文中给出了对于不同的数据集,剪枝结果,不同程度的剪枝可减小不同程度的参数量,剪枝越多则参数量越少,但模型性能会退化,具体剪枝情况由不同数据集而异;此处图中,肺结节在L2就有较好的结果了;
在这里插入图片描述

3 分割结果对比

3.1 论文用到数据集

在这里插入图片描述

3.2 预测结果

论文中给出了在4种不同数据集上进行测试的结果:

(1)其中 U-Net 为 benchmark, wide U-Net 是加宽后的U-Net结构, 用于单纯增加网络参数量,便于控制变量进行对比;

(2)Table 3 是不同网络的对比,包括总参数大小、IOU系数;
在这里插入图片描述
在这里插入图片描述

4 源码解析

此处展示的是原始 U-Net++ 结构代码,对照网络结构看起来更易理解,存放在helper_functions.py中。代码是2D的,对于3D 图像,添加 depth 即可,同理;

作者也提供了在其他 backbone 进行U Net++化的代码,此处就不详细列举了;

########################################
# 2D Standard
########################################

def standard_unit(input_tensor, stage, nb_filter, kernel_size=3):

    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_1', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(input_tensor)
    x = Dropout(dropout_rate, name='dp'+stage+'_1')(x)
    x = Conv2D(nb_filter, (kernel_size, kernel_size), activation=act, name='conv'+stage+'_2', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(x)
    x = Dropout(dropout_rate, name='dp'+stage+'_2')(x)

    return x


Standard UNet++ [Zhou et.al, 2018]
Total params: 9,041,601
"""
def UNetPlusPlus(img_rows, img_cols, color_type=1, num_class=1, deep_supervision=False):

    nb_filter = [32,64,128,256,512]

    # Handle Dimension Ordering for different backends
    global bn_axis
    if K.image_dim_ordering() == 'tf':
      bn_axis = 3
      img_input = Input(shape=(img_rows, img_cols, color_type), name='main_input')
    else:
      bn_axis = 1
      img_input = Input(shape=(color_type, img_rows, img_cols), name='main_input')

    conv1_1 = standard_unit(img_input, stage='11', nb_filter=nb_filter[0])
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)

    conv2_1 = standard_unit(pool1, stage='21', nb_filter=nb_filter[1])
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)

    up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
    conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=bn_axis)
    conv1_2 = standard_unit(conv1_2, stage='12', nb_filter=nb_filter[0])

    conv3_1 = standard_unit(pool2, stage='31', nb_filter=nb_filter[2])
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)

    up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
    conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=bn_axis)
    conv2_2 = standard_unit(conv2_2, stage='22', nb_filter=nb_filter[1])

    up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
    conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=bn_axis)
    conv1_3 = standard_unit(conv1_3, stage='13', nb_filter=nb_filter[0])

    conv4_1 = standard_unit(pool3, stage='41', nb_filter=nb_filter[3])
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)

    up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
    conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=bn_axis)
    conv3_2 = standard_unit(conv3_2, stage='32', nb_filter=nb_filter[2])

    up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
    conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=bn_axis)
    conv2_3 = standard_unit(conv2_3, stage='23', nb_filter=nb_filter[1])

    up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
    conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=bn_axis)
    conv1_4 = standard_unit(conv1_4, stage='14', nb_filter=nb_filter[0])

    conv5_1 = standard_unit(pool4, stage='51', nb_filter=nb_filter[4])

    up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
    conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
    conv4_2 = standard_unit(conv4_2, stage='42', nb_filter=nb_filter[3])

    up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
    conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=bn_axis)
    conv3_3 = standard_unit(conv3_3, stage='33', nb_filter=nb_filter[2])

    up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
    conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=bn_axis)
    conv2_4 = standard_unit(conv2_4, stage='24', nb_filter=nb_filter[1])

    up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
    conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=bn_axis)
    conv1_5 = standard_unit(conv1_5, stage='15', nb_filter=nb_filter[0])

    nestnet_output_1 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_1', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_2)
    nestnet_output_2 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_2', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_3)
    nestnet_output_3 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_3', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_4)
    nestnet_output_4 = Conv2D(num_class, (1, 1), activation='sigmoid', name='output_4', kernel_initializer = 'he_normal', padding='same', kernel_regularizer=l2(1e-4))(conv1_5)

    if deep_supervision:
    	# 计算Loss时,4个分支权重为1:1:1:1
        model = Model(input=img_input, output=[nestnet_output_1,
                                               nestnet_output_2,
                                               nestnet_output_3,
                                               nestnet_output_4])
    else:
        model = Model(input=img_input, output=[nestnet_output_4])

    return model
    
  • 1
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值