【医学分割】u2net

概述

u2net是在unet的基础上提出来的,是一个效果非常棒的显著性目标检测模型。
显著性目标检测:分割出图像的主体。
模型提出的主要背景是两部分:
1、分割任务的backbone主要是一些预训练模型。因为一些分割backbone的效果没有预训练模型的效果好,所以大家也都采用预训练模型了。问题:他么并不是为了分割而设计的,对于分割任务中比较关键的局部细节和全局对比信息关注不够,所以想要更加好的适应分割还需要在这个基础上加一些特殊的结构,实现对于提取的特征的更好的利用,这就带来了计算复杂度。
2、模型一般比较深,考虑到显存和计算的开销,会在最前面的一些层中对图片先做一些下采样操作,降低分辨率。问题:其实对于分割任务而言,高分辨率图像中丰富的空间信息是需要好好利用的。

u2net成功的解决了这两个问题,提出了两层嵌套的u形结构,可以从头训练网络而不依赖于预训练模型,特征提取的效果和预训练模型一样棒,并且不需要额外的结构处理这些特征。另外它在网络变深之后,还能保持高分辨率(因为RSU中大量的池化操作),并且带来的的显存和计算量的开销也是非常友好的。

细节

网络结构

以下就是u2net的网络结构,总体还是unet的u形结构,但是其中的每层或者每个stage由unet中朴素的卷积结构变为了RUS(Residual U-blocks),然后每个decoder的侧边输出都收到gt的监督,所有的侧边输出做concat之后,在进行卷积操作就是最终的网络输出了。
之前的研究还有堆叠或者级联unet得到 u ∗ n − n e t u*n-net unnet,但是作者是嵌套或者指数表示unuet去了,显然嵌套的数量可以很大,即 u n n e t u^n net unnet,但是考虑实际情况,还是嵌套一层得到 u 2 n e t u^2 net u2net
然后每个RSU的层数随着encoder的层数的增加而减少,即En_1、En_2、En_3、En_4使用的分别是RSU-7、RSU-6、RSU-5、RSU-4,因为我们重视对于高分辨率特征图的特征提取,同时会进行池化操作,降低尺寸。而En_5、En_6采用的就是RSU-4F了,F表示不会变化尺寸,也就是只进行特征提取。
在这里插入图片描述

RSU(Residual U-blocks)

RUS替换了unet中朴素的卷积块,他能够更好的捕捉全局和局部的信息,而以往的1x1,3x3的卷积由于感受野的缘故,往往擅长捕捉局部的信息,对于全局信息的捕捉没有那么良好,而全局信息往往是分割分割所需要的。RUS通过这个u形状结构实现不同尺度不同感受野的特征图的混合,能够捕捉来自更多的不同尺度的全局信息。
在这里插入图片描述

并且他还使用了残差的思想。resnet中需要至少两层才能做恒等映射,不然就是做线性变换,而n次的线性变化效果等价与1次的线性变换。而本文中,由于这个u-block包含了若干层了,所以跨越一个block就行了。
在这里插入图片描述
然后是计算量方面,作者将一些主流的块结构进行了对比,发现虽然RSU的计算量随着深度的增加是线性的,但是系数很小,故计算量其实没有怎么大,是可以堆叠的很深的。
在这里插入图片描述

损失

主要就是两部分,一部分是侧边输出特征图的损失,另一部分是这些侧边输出融合之后形成的最终输出特征图的损失。
在这里插入图片描述

在这里插入图片描述

简单实现

import paddle
import paddle.nn as nn

# 卷积块
# 参数dilation 默认是1 即标准卷积,设置为2 表示膨胀卷积或者说是空洞卷积
class ConvalutionBlock(nn.Layer):
    def __init__(self,in_channels,out_channels,dilation=1):
        super(ConvalutionBlock, self).__init__()
        self.layer=nn.Sequential(
            nn.Conv2D(in_channels,out_channels,3,1,1,dilation),
            nn.BatchNorm2D(out_channels),
            nn.LeakyReLU()
        )

    def forward(self,x):
        return self.layer(x)

# 下采样块
class DownBlock(nn.Layer):
    def __init__(self,in_channels,out_channels):
        super(DownBlock, self).__init__()
        self.layer=nn.Sequential(
            nn.MaxPool2D(1),
            ConvalutionBlock(in_channels,out_channels)
        )
    def forward(self,x):
        return self.layer(x)

# 上采样块
# 将decoder当前层上采样并且和encoder当前层做concat
# 然后再经过3x3的卷积块
class UpBlock(nn.Layer):
    def __init__(self,in_channels,out_channels):
        super(UpBlock, self).__init__()
        self.upsample=nn.UpsamplingBilinear2D(scale_factor=2)
        self.conv=ConvalutionBlock(in_channels,out_channels)

    def forward(self, x1, x2):
        x1 = self.upsample(x1)
        # 因为tensor是ncwh的 我们需要在c维度上concat 所以axis是1
        x=paddle.concat([x2, x1], axis=1)
        return self.conv(x)

# 第一个stage的RSU或者说是作者论文中说的7层的RSU
class RSU1(nn.Layer):
    def __init__(self,in_channels,mid_channels,out_channels):
        super(RSU1, self).__init__()
        self.pool=nn.MaxPool2D(2)

        self.first_layer=ConvalutionBlock(in_channels,out_channels)

        self.en_1=ConvalutionBlock(out_channels,mid_channels)
        self.en_2=DownBlock(mid_channels,mid_channels)
        self.en_3=DownBlock(mid_channels,mid_channels)
        self.en_4=DownBlock(mid_channels,mid_channels)
        self.en_5=DownBlock(mid_channels,mid_channels)
        self.en_6=DownBlock(mid_channels,mid_channels)
        self.en_7=ConvalutionBlock(mid_channels,mid_channels,2)

        self.de_6=ConvalutionBlock(mid_channels*2,mid_channels)
        self.de_5=UpBlock(mid_channels*2,mid_channels)
        self.de_4=UpBlock(mid_channels*2,mid_channels)
        self.de_3=UpBlock(mid_channels*2,mid_channels)
        self.de_2=UpBlock(mid_channels*2,mid_channels)
        self.de_1=UpBlock(mid_channels*2,out_channels)

    def forward(self,x):
        first_layer=self.first_layer(x)
        en_1=self.en_1(first_layer)
        en_2=self.en_2(en_1)
        en_3=self.en_3(en_2)
        en_4=self.en_4(en_3)
        en_5=self.en_5(en_4)
        en_6=self.en_6(en_5)
        en_7=self.en_7(en_6)

        de_6=self.de_6(paddle.concat([en_6,en_7],axis=1))
        de_5=self.de_5(de_6,en_5)
        de_4=self.de_4(de_5,en_4)
        de_3=self.de_3(de_4,en_3)
        de_2=self.de_2(de_3,en_2)
        de_1=self.de_1(de_2,en_1)

        return first_layer+de_1

# 第二个stage的RSU或者说是作者论文中说的6层的RSU
class RSU2(nn.Layer):
    def __init__(self,in_channels,mid_channels,out_channels):
        super(RSU1, self).__init__()
        self.pool=nn.MaxPool2D(2)

        self.first_layer=ConvalutionBlock(in_channels,out_channels)

        self.en_1 = ConvalutionBlock(out_channels, mid_channels)
        self.en_2 = DownBlock(mid_channels, mid_channels)
        self.en_3 = DownBlock(mid_channels, mid_channels)
        self.en_4 = DownBlock(mid_channels, mid_channels)
        self.en_5 = DownBlock(mid_channels, mid_channels)
        self.en_6 = ConvalutionBlock(mid_channels, mid_channels, 2)

        self.de_5 = ConvalutionBlock(mid_channels * 2, mid_channels)
        self.de_4 = UpBlock(mid_channels * 2, mid_channels)
        self.de_3 = UpBlock(mid_channels * 2, mid_channels)
        self.de_2 = UpBlock(mid_channels * 2, mid_channels)
        self.de_1 = UpBlock(mid_channels * 2, out_channels)

    def forward(self,x):
        first_layer=self.first_layer(x)
        en_1=self.en_1(first_layer)
        en_2=self.en_2(en_1)
        en_3=self.en_3(en_2)
        en_4=self.en_4(en_3)
        en_5=self.en_5(en_4)
        en_6=self.en_6(en_5)

        de_5=self.de_5(paddle.concat([en_5,en_6],axis=1))
        de_4=self.de_4(de_5,en_4)
        de_3=self.de_3(de_4,en_3)
        de_2=self.de_2(de_3,en_2)
        de_1=self.de_1(de_2,en_1)

        return first_layer+de_1

# 第三个stage的RSU或者说是作者论文中说的5层的RSU
class RSU3(nn.Layer):
    def __init__(self,in_channels,mid_channels,out_channels):
        super(RSU1, self).__init__()
        self.pool=nn.MaxPool2D(2)

        self.first_layer=ConvalutionBlock(in_channels,out_channels)

        self.en_1 = ConvalutionBlock(out_channels, mid_channels)
        self.en_2 = DownBlock(mid_channels, mid_channels)
        self.en_3 = DownBlock(mid_channels, mid_channels)
        self.en_4 = DownBlock(mid_channels, mid_channels)
        self.en_5 = ConvalutionBlock(mid_channels, mid_channels, 2)

        self.de_4 = ConvalutionBlock(mid_channels * 2, mid_channels)
        self.de_3 = UpBlock(mid_channels * 2, mid_channels)
        self.de_2 = UpBlock(mid_channels * 2, mid_channels)
        self.de_1 = UpBlock(mid_channels * 2, out_channels)

    def forward(self,x):
        first_layer=self.first_layer(x)
        en_1=self.en_1(first_layer)
        en_2=self.en_2(en_1)
        en_3=self.en_3(en_2)
        en_4=self.en_4(en_3)
        en_5=self.en_5(en_4)

        de_4=self.de_4(paddle.concat([en_4,en_5],axis=1))
        de_3=self.de_3(de_4,en_3)
        de_2=self.de_2(de_3,en_2)
        de_1=self.de_1(de_2,en_1)

        return first_layer+de_1

# 第四个stage的RSU或者说是作者论文中说的4层的RSU
class RSU4(nn.Layer):
    def __init__(self,in_channels,mid_channels,out_channels):
        super(RSU1, self).__init__()
        self.pool=nn.MaxPool2D(2)

        self.first_layer=ConvalutionBlock(in_channels,out_channels)

        self.en_1 = ConvalutionBlock(out_channels, mid_channels)
        self.en_2 = DownBlock(mid_channels, mid_channels)
        self.en_3 = DownBlock(mid_channels, mid_channels)
        self.en_4 = ConvalutionBlock(mid_channels, mid_channels, 2)

        self.de_3 = ConvalutionBlock(mid_channels * 2, mid_channels)
        self.de_2 = UpBlock(mid_channels * 2, mid_channels)
        self.de_1 = UpBlock(mid_channels * 2, out_channels)

    def forward(self,x):
        first_layer=self.first_layer(x)
        en_1=self.en_1(first_layer)
        en_2=self.en_2(en_1)
        en_3=self.en_3(en_2)
        en_4=self.en_4(en_3)

        de_3=self.de_4(paddle.concat([en_3,en_4],axis=1))
        de_2=self.de_2(de_3,en_2)
        de_1=self.de_1(de_2,en_1)

        return first_layer+de_1

# 最后一个stage 不改变尺寸 只做特征提取
class RSU4F(nn.Layer):
    def __init__(self,in_channels,mid_channels,out_channels):
        super(RSU4F, self).__init__()
        self.first_layer=ConvalutionBlock(in_channels,out_channels)
        self.en_1 = ConvalutionBlock(out_channels, mid_channels)
        self.en_2 = ConvalutionBlock(mid_channels, mid_channels,2)
        self.en_3 = ConvalutionBlock(mid_channels, mid_channels,4)
        self.en_4 = ConvalutionBlock(mid_channels, mid_channels,8)

        self.de_3 = ConvalutionBlock(mid_channels * 2, mid_channels,4)
        self.de_2 = ConvalutionBlock(mid_channels * 2, mid_channels,2)
        self.de_1 = ConvalutionBlock(mid_channels * 2, mid_channels)

    def forward(self,x):
        first_layer=self.first_layer(x)
        en_1=self.en_1(first_layer)
        en_2=self.en_2(en_1)
        en_3=self.en_3(en_2)
        en_4=self.en_4(en_3)

        de_3=self.de_3(paddle.concat([en_3,en_4],axis=1))
        de_2=self.de_2(paddle.concat([en_2,de_3],axis=1))
        de_1=self.de_1(paddle.concat([en_1,de_2],axis=1))

        return first_layer+de_1

class U2Net(nn.Layer):
    def __init__(self,in_channel=3,out_channel=1):
        super(U2Net, self).__init__()
        self.in_channel=in_channel
        self.out_channel=out_channel

        self.en_1 = nn.Sequential(
            RSU1(self.in_channel,32,64),
            nn.MaxPool2D(2)
        )
        self.en_2 = nn.Sequential(
            RSU2(64, 32, 128),
            nn.MaxPool2D(2)
        )
        self.en_3 = nn.Sequential(
            RSU3(128, 64, 256),
            nn.MaxPool2D(2)
        )
        self.en_4 = nn.Sequential(
            RSU4(256,128, 512),
            nn.MaxPool2D(2)
        )
        self.en_5 = nn.Sequential(
            RSU4F(512, 256, 512),
            nn.MaxPool2D(2)
        )
        self.en_6=RSU4F(512, 256, 512)

        self.de_5=RSU4F(1024, 256, 512)
        self.de_4=RSU4(1024, 128, 256)
        self.de_3=RSU3(512, 64, 128)
        self.de_2=RSU2(256, 32, 64)
        self.de_1=RSU1(128, 32, 64)

        self.sup1 = nn.Conv2D(64,self.out_channel)
        self.sup2 = nn.Conv2D(64,self.out_channel)
        self.sup3 = nn.Conv2D(128,self.out_channel)
        self.sup4 = nn.Conv2D(512,self.out_channel)
        self.sup5 = nn.Conv2D(512,self.out_channel)
        self.sup6 = nn.Conv2D(512,self.out_channel)

        self.sup0=nn.Conv2D(6*self.out_channel,1,3,1,1)

    def forward(self,x):
        en_1=self.en_1(x)
        en_2=self.en_2(en_1)
        en_3=self.en_3(en_2)
        en_4=self.en_4(en_3)
        en_5=self.en_5(en_4)
        en_6=self.en_6(en_5)

        de_5=self.de_5(paddle.concat([en_5,en_6],axis=1))
        de_4=self.de_4(UpBlock(de_5,en_4))
        de_3=self.de_3(UpBlock(de_4,en_3))
        de_2=self.de_2(UpBlock(de_3,en_2))
        de_1=self.de_1(UpBlock(de_2,en_1))

        sup1 = self.sup1(de_1)
        sup2 = self.sup2(de_2)
        sup3 = self.sup3(de_3)
        sup4 = self.sup4(de_4)
        sup5 = self.sup5(de_5)
        sup6 = self.sup6(en_6)

        sup0=self.sup0(paddle.concat([sup1,sup2,sup3,sup4,sup5,sup6]))

        return sup0

if __name__ == '__main__':
    # x=paddle.randn(shape=[2,3,256,256])
    unet=U2Net()
    # print(net(x).shape)
    paddle.summary(unet, (1,3,256,256))


  • 3
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值