UNet++学习笔记

序言

本文整理于作者知乎原文:研习UNet,UNet++的作者在知乎讲的非常仔细,感兴趣的可以直接去围观,这里只是为了方便记忆做个总结。

正文

一、图像分割背景

在计算机视觉领域,全卷积网络(FCN)是比较有名的图像分割网络,医学图像处理方向,U-Net可以说是一个更加炙手可热的网络,基本上所有的分割问题,我们都会拿U-Net先看一下基本的结果,然后进行“魔改”。
在这里插入图片描述
UNet和FCN对比:

  1. 两者均是基于encoder-deconder,发表时间都是2015年,UNet稍晚;
  2. UNet网络结构完全对称,FCN的decoder相对简单,只用了一个deconvolution的操作;
  3. skip connection区别,FCN用的是加操作(summation),UNet用的是叠操作(concatenation)。

图像分割的思想:

输入一张图片,经过下采样编码后,得到一串比原先更小的特征,相当于压缩,在经过一个解码,理想状态是还原原来的图像。继续简化就是,一幅图,编码,或者说降采样,然后解码,也就是上采样,然后输出一个分割结果。

UNet网络结构:

  • 下采样
  • 上采样
  • skip connection
    在这里插入图片描述

由UNet联想到的问题:

  1. Unet这个三年不动的拓扑结构真的一点儿毛病都没有吗?

这个结构最优秀的点就在于结合了深浅层特征,而且结合的方式优于FCN,很多模型都是基于这个基础上做改进。

  1. Unet要多深合适?

并不是所有的问题增加深度就能提高性能。有些问题简单,浅层网络就能解决,增加深度并不能增加效果。有些问题难,就需要深层网络。

  1. 降采样对于分割网络到底是不是必须的?

它可以增加对输入图像的一些小扰动的鲁棒性,比如图像平移,旋转等,减少过拟合的风险,降低运算量,和增加感受野的大小。

  1. 所抓取的特征都很重要,为什么我非要降16倍了才开始上采样回去呢?

并不是非要降低16倍,而是根据自身问题的难度选择降低的倍数。更简明的说,就是你选取多少层,这个你得试试。

二、UNet++

设计思路:

  1. 对UNet四个下采样层的每个下采样后的特征图都进行上采样,针对于不同的任务找到适合的,所以UNet可以变为多个UNet结构

在这里插入图片描述

  1. 但是总不能每个结构都要训练一次,所以把以上结构结合起来,共用一个提取器,就是把1~4层的U-Net全给连一起了:
    在这里插入图片描述
  2. 但是这个网络是不能被训练的,在反向传播的时候中间的红色区域没有梯度
    在这里插入图片描述
  3. 强行加上梯度,增加了短链接,去掉了本身的长连接:
    在这里插入图片描述
  4. 但是作者认为U-Net中的长连接还是有必要的,它联系了输入图像的很多信息,它和残差的操作非常类似,也就是residual操作,x+f(x),所以再加上长连接:
    在这里插入图片描述

最后UNet++的基本结构就是这样的了,为什么不一步到胃直接贴上这张结构图,其实我们从作者的设计思路来走一遍,才能对这张图有更好的认识。

思考:

Unet++网络比U-Net效果好,但是这个网络增加了多少的参数,加粗的参数可都是比U-Net多出来的啊?是不是通过增加参数就能达到Unet++的能力?看下图:
在这里插入图片描述
上图说明:

作者设计了一个叫wide U-Net的参考结构,先来看看UNet++的参数数量是9.04M,而U-Net是7.76M,多了差不多16%的参数,所以wide U-Net我们在设计时就让它的参数比UNet++差不多,并且还稍微多一点点,来证明并不是无脑增加参数量,模型效果就会好。(这部分在知乎那篇文章有详细解释,这里只搬来了结论)

三、UNet++ 训练

重新温习一下UNet++网络结构:
在这里插入图片描述
问题:

如果只用最右边的一个loss来做的话,在反向传播的时候中间部分会收不到过来的梯度。

解决:

加入深监督,也就是deep supervision,具体的实现操作就是在图中X01,X02,X03 ,X04后面加一个1x1的卷积核,相当于去监督每个level,或者每个分支的U-Net的输出。
在这里插入图片描述
思考:在训练过程中在各个level的子网络中加了这种深监督,可以带来怎样的好处呢?

答案:剪枝!

  1. 为什么UNet++可以被剪枝?
  2. 如何剪枝?
  3. 剪枝好处在哪里?
    在这里插入图片描述
1. 为什么可以剪枝?

看上图,在测试的阶段,由于输入的图像只会前向传播,扔掉L4这部分对前面的输出完全没有影响的,而在训练阶段,因为既有前向,又有反向传播,被剪掉的部分是会帮助其他部分做权重更新的。也就是说测试时,剪掉部分对剩余结构不做影响,训练时,剪掉部分对剩余部分有影响。

理解:

因为在深监督的过程中,每个子网络的输出都其实已经是图像的分割结果了,所以如果小的子网络的输出结果已经足够好了,我们可以随意的剪掉那些多余的部分了。

我们把每个剪完剩下的子网络根据它们的深度命名为UNet++ L1,L2,L3,L4,后面会简称为L1~L4。最理想的状态是什么?当然是L1喽,如果L1的输出结果足够好,剪完以后的分割网络会变得非常的小。

2. 如何去剪多少?

因为在训练模型的时候会把数据分为训练集,验证集和测试集,训练集上是一定拟合的很好的,测试集是我们不能碰的,所以我们会根据子网络在验证集的结果来决定剪多少。所谓的验证集就是一开始从训练集中分出来的数据,用来监测训练过程用的。
在这里插入图片描述
图片解释:

先看看L1~L4的网络参数量,差了好多,L1只有0.1M,而L4有9M,也就是理论上如果L1的结果我是满意的,那么模型可以被剪掉的参数达到98.8%。不过根据我们的四个数据集,L1的效果并不会那么好,因为太浅了嘛。但是其中有三个数据集显示L2的结果和L4已经非常接近了,也就是说对于这三个数据集,在测试阶段,我们不需要用9M的网络,用半M的网络足够了。

3. 剪枝的好处

剪枝应用最多的就是在移动手机端了,根据模型的参数量,如果L2得到的效果和L4相近,模型的内存可以省18倍,大大减少了模型的参数量。

总结:

UNet++的第一个优势就是精度的提升,这个应该它整合了不同层次的特征所带来的,第二个是灵活的网络结构配合深监督,让参数量巨大的深度网络在可接受的精度范围内大幅度的缩减参数量。

四、代码

代码来源:Pytorch实现UNet++

import torch
import torch.nn as nn


class ConvSamePad2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, bias: bool = True):
        super().__init__()

        left_top_pad = right_bottom_pad = kernel_size // 2
        if kernel_size % 2 == 0:
            right_bottom_pad -= 1

        self.layer = nn.Sequential(
            nn.ReflectionPad2d((left_top_pad, right_bottom_pad, left_top_pad, right_bottom_pad)),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)
        )

    def forward(self, inputs):
        return self.layer(inputs)


class Conv3x3(nn.Module):
    def __init__(self, in_channels, out_channels, drop_rate=0.5):
        super().__init__()
        self.layer = nn.Sequential(
            ConvSamePad2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
            nn.Dropout2d(p=drop_rate),
            ConvSamePad2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3),
            nn.Dropout2d(p=drop_rate)
        )

    def forward(self, inputs):
        return self.layer(inputs)


class Conv1x1(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.layer = nn.Sequential(
            ConvSamePad2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.layer(inputs)


class Unet(nn.Module):
    def __init__(self, in_channels, n_classes, deep_supervision=True):
        super().__init__()
        self.deep_supervision = deep_supervision
        
        filters = [32, 64, 128, 256, 512]

        # j == 0
        self.x_00 = Conv3x3(in_channels=in_channels, out_channels=filters[0])
        self.pool0 = nn.MaxPool2d(kernel_size=2)

        self.x_01 = Conv3x3(in_channels=filters[0] * 2, out_channels=filters[0])
        self.x_02 = Conv3x3(in_channels=filters[0] * 3, out_channels=filters[0])
        self.x_03 = Conv3x3(in_channels=filters[0] * 4, out_channels=filters[0])
        self.x_04 = Conv3x3(in_channels=filters[0] * 5, out_channels=filters[0])

        self.up_10_to_01 = nn.ConvTranspose2d(in_channels=filters[1], out_channels=filters[0], kernel_size=2, stride=2)
        self.up_11_to_02 = nn.ConvTranspose2d(in_channels=filters[1], out_channels=filters[0], kernel_size=2, stride=2)
        self.up_12_to_03 = nn.ConvTranspose2d(in_channels=filters[1], out_channels=filters[0], kernel_size=2, stride=2)
        self.up_13_to_04 = nn.ConvTranspose2d(in_channels=filters[1], out_channels=filters[0], kernel_size=2, stride=2)


        # j == 1
        self.x_10 = Conv3x3(in_channels=filters[0], out_channels=filters[1])
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.x_11 = Conv3x3(in_channels=filters[1] * 2, out_channels=filters[1])
        self.x_12 = Conv3x3(in_channels=filters[1] * 3, out_channels=filters[1])
        self.x_13 = Conv3x3(in_channels=filters[1] * 4, out_channels=filters[1])

        self.up_20_to_11 = nn.ConvTranspose2d(in_channels=filters[2], out_channels=filters[1], kernel_size=2, stride=2)
        self.up_21_to_12 = nn.ConvTranspose2d(in_channels=filters[2], out_channels=filters[1], kernel_size=2, stride=2)
        self.up_22_to_13 = nn.ConvTranspose2d(in_channels=filters[2], out_channels=filters[1], kernel_size=2, stride=2)


        # j == 2
        self.x_20 = Conv3x3(in_channels=filters[1], out_channels=filters[2])
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.x_21 = Conv3x3(in_channels=filters[2] * 2, out_channels=filters[2])
        self.x_22 = Conv3x3(in_channels=filters[2] * 3, out_channels=filters[2])

        self.up_30_to_21 = nn.ConvTranspose2d(in_channels=filters[3], out_channels=filters[2], kernel_size=2, stride=2)
        self.up_31_to_22 = nn.ConvTranspose2d(in_channels=filters[3], out_channels=filters[2], kernel_size=2, stride=2)


        # j == 3
        self.x_30 = Conv3x3(in_channels=filters[2], out_channels=filters[3])
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.x_31 = Conv3x3(in_channels=filters[3] * 2, out_channels=filters[3])

        self.up_40_to_31 = nn.ConvTranspose2d(in_channels=filters[4], out_channels=filters[3], kernel_size=2, stride=2)


        # j == 4
        self.x_40 = Conv3x3(in_channels=filters[3], out_channels=filters[4])


        # 1x1 conv layer
        self.final_1x1_x01 = Conv1x1(in_channels=filters[0], out_channels=n_classes)
        self.final_1x1_x02 = Conv1x1(in_channels=filters[0], out_channels=n_classes)
        self.final_1x1_x03 = Conv1x1(in_channels=filters[0], out_channels=n_classes)
        self.final_1x1_x04 = Conv1x1(in_channels=filters[0], out_channels=n_classes)

    def forward(self, inputs, L=4):
        if not (1 <= L <= 4):
            raise ValueError("the model pruning factor `L` should be 1 <= L <= 4")

        x_00_output = self.x_00(inputs)
        x_10_output = self.x_10(self.pool0(x_00_output))
        x_10_up_sample = self.up_10_to_01(x_10_output)
        x_01_output = self.x_01(torch.cat([x_00_output, x_10_up_sample], 1))
        nestnet_output_1 = self.final_1x1_x01(x_01_output)

        if L == 1:
            return nestnet_output_1

        x_20_output = self.x_20(self.pool1(x_10_output))
        x_20_up_sample = self.up_20_to_11(x_20_output)
        x_11_output = self.x_11(torch.cat([x_10_output, x_20_up_sample], 1))
        x_11_up_sample = self.up_11_to_02(x_11_output)
        x_02_output = self.x_02(torch.cat([x_00_output, x_01_output, x_11_up_sample], 1))
        nestnet_output_2 = self.final_1x1_x01(x_02_output)

        if L == 2:
            if self.deep_supervision:
                # return the average of output layers
                return (nestnet_output_1 + nestnet_output_2) / 2
            else:
                return nestnet_output_2

        x_30_output = self.x_30(self.pool2(x_20_output))
        x_30_up_sample = self.up_30_to_21(x_30_output)
        x_21_output = self.x_21(torch.cat([x_20_output, x_30_up_sample], 1))
        x_21_up_sample = self.up_21_to_12(x_21_output)
        x_12_output = self.x_12(torch.cat([x_10_output, x_11_output, x_21_up_sample], 1))
        x_12_up_sample = self.up_12_to_03(x_12_output)
        x_03_output = self.x_03(torch.cat([x_00_output, x_01_output, x_02_output, x_12_up_sample], 1))
        nestnet_output_3 = self.final_1x1_x01(x_03_output)

        if L == 3:
            # return the average of output layers
            if self.deep_supervision:
                return (nestnet_output_1 + nestnet_output_2 + nestnet_output_3) / 3
            else:
                return nestnet_output_3

        x_40_output = self.x_40(self.pool3(x_30_output))
        x_40_up_sample = self.up_40_to_31(x_40_output)
        x_31_output = self.x_31(torch.cat([x_30_output, x_40_up_sample], 1))
        x_31_up_sample = self.up_31_to_22(x_31_output)
        x_22_output = self.x_22(torch.cat([x_20_output, x_21_output, x_31_up_sample], 1))
        x_22_up_sample = self.up_22_to_13(x_22_output)
        x_13_output = self.x_13(torch.cat([x_10_output, x_11_output, x_12_output, x_22_up_sample], 1))
        x_13_up_sample = self.up_13_to_04(x_13_output)
        x_04_output = self.x_04(torch.cat([x_00_output, x_01_output, x_02_output, x_03_output, x_13_up_sample], 1))
        nestnet_output_4 = self.final_1x1_x01(x_04_output)

        if L == 4:
            if self.deep_supervision:
                # return the average of output layers
                return (nestnet_output_1 + nestnet_output_2 + nestnet_output_3 + nestnet_output_4) / 4
            else:
                return nestnet_output_4


if __name__ == '__main__':
    inputs = torch.rand((3, 1, 96, 96)).cuda()

    unet_plus_plus = Unet(in_channels=1, n_classes=3).cuda()

    from datetime import datetime

    st = datetime.now()
    output = unet_plus_plus(inputs, L=1)
    print(f"{(datetime.now() - st).total_seconds(): .4f}s")

训练自己的数据集可以通过搭建自己的Unet语义分割平台来实现。首先,你需要准备自己的数据集,包括训练集和验证集。可以参考引用中的博客文章和引用中的代码,根据自己的数据集生成相应的txt文件。txt文件的内容是模型训练和测试过程中读入图像数据的名称。 接下来,你可以使用PyTorch来搭建Unet语义分割模型。可以参考引用中的博客文章,其中介绍了如何使用PyTorch搭建自己的Unet语义分割平台。你可以根据自己的需求进行模型的修改和优化。 在训练过程中,你可以使用自己准备的数据集进行模型训练。可以参考引用中的博客文章和引用中的代码,利用生成的txt文件读取图像数据并进行训练。 训练自己的数据集需要一定的时间和计算资源,同时还需要进行适当的参数调整和优化。建议在训练过程中监控模型的性能指标,如损失函数和准确率,并根据需要进行调整和改进。 通过以上步骤,你就可以训练自己的数据集并应用Unet模型进行语义分割任务了。祝你成功!<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [学习笔记Unet学习及训练自己的数据集](https://blog.csdn.net/Qingkaii/article/details/124474485)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [SwinUnet官方代码训练自己数据集(单通道灰度图像的分割)](https://blog.csdn.net/qq_37652891/article/details/123932772)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT0_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值