概述
u2net是在unet的基础上提出来的,是一个效果非常棒的显著性目标检测模型。
显著性目标检测:分割出图像的主体。
模型提出的主要背景
是两部分:
1、分割任务的backbone主要是一些预训练模型。因为一些分割backbone的效果没有预训练模型的效果好,所以大家也都采用预训练模型了。问题
:他么并不是为了分割而设计的,对于分割任务中比较关键的局部细节和全局对比信息关注不够,所以想要更加好的适应分割还需要在这个基础上加一些特殊的结构,实现对于提取的特征的更好的利用,这就带来了计算复杂度。
2、模型一般比较深,考虑到显存和计算的开销,会在最前面的一些层中对图片先做一些下采样操作,降低分辨率。问题
:其实对于分割任务而言,高分辨率图像中丰富的空间信息是需要好好利用的。
u2net成功的解决了这两个问题,提出了两层嵌套的u形结构,可以从头训练网络而不依赖于预训练模型,特征提取的效果和预训练模型一样棒,并且不需要额外的结构处理这些特征。另外它在网络变深之后,还能保持高分辨率(因为RSU中大量的池化操作),并且带来的的显存和计算量的开销也是非常友好的。
细节
网络结构
以下就是u2net的网络结构,总体还是unet的u形结构,但是其中的每层或者每个stage由unet中朴素的卷积结构变为了RUS(Residual U-blocks),然后每个decoder的侧边输出都收到gt的监督,所有的侧边输出做concat之后,在进行卷积操作就是最终的网络输出了。
之前的研究还有堆叠或者级联unet得到
u
∗
n
−
n
e
t
u*n-net
u∗n−net,但是作者是嵌套或者指数表示unuet去了,显然嵌套的数量可以很大,即
u
n
n
e
t
u^n net
unnet,但是考虑实际情况,还是嵌套一层得到
u
2
n
e
t
u^2 net
u2net
然后每个RSU的层数随着encoder的层数的增加而减少,即En_1、En_2、En_3、En_4使用的分别是RSU-7、RSU-6、RSU-5、RSU-4,因为我们重视对于高分辨率特征图的特征提取,同时会进行池化操作,降低尺寸。而En_5、En_6采用的就是RSU-4F了,F表示不会变化尺寸,也就是只进行特征提取。
RSU(Residual U-blocks)
RUS替换了unet中朴素的卷积块,他能够更好的捕捉全局和局部的信息,而以往的1x1,3x3的卷积由于感受野的缘故,往往擅长捕捉局部的信息,对于全局信息的捕捉没有那么良好,而全局信息往往是分割分割所需要的。RUS通过这个u形状结构实现不同尺度不同感受野的特征图的混合,能够捕捉来自更多的不同尺度的全局信息。
并且他还使用了残差的思想。resnet中需要至少两层才能做恒等映射,不然就是做线性变换,而n次的线性变化效果等价与1次的线性变换。而本文中,由于这个u-block包含了若干层了,所以跨越一个block就行了。
然后是计算量方面,作者将一些主流的块结构进行了对比,发现虽然RSU的计算量随着深度的增加是线性的,但是系数很小,故计算量其实没有怎么大,是可以堆叠的很深的。
损失
主要就是两部分,一部分是侧边输出特征图的损失,另一部分是这些侧边输出融合之后形成的最终输出特征图的损失。
简单实现
import paddle
import paddle.nn as nn
# 卷积块
# 参数dilation 默认是1 即标准卷积,设置为2 表示膨胀卷积或者说是空洞卷积
class ConvalutionBlock(nn.Layer):
def __init__(self,in_channels,out_channels,dilation=1):
super(ConvalutionBlock, self).__init__()
self.layer=nn.Sequential(
nn.Conv2D(in_channels,out_channels,3,1,1,dilation),
nn.BatchNorm2D(out_channels),
nn.LeakyReLU()
)
def forward(self,x):
return self.layer(x)
# 下采样块
class DownBlock(nn.Layer):
def __init__(self,in_channels,out_channels):
super(DownBlock, self).__init__()
self.layer=nn.Sequential(
nn.MaxPool2D(1),
ConvalutionBlock(in_channels,out_channels)
)
def forward(self,x):
return self.layer(x)
# 上采样块
# 将decoder当前层上采样并且和encoder当前层做concat
# 然后再经过3x3的卷积块
class UpBlock(nn.Layer):
def __init__(self,in_channels,out_channels):
super(UpBlock, self).__init__()
self.upsample=nn.UpsamplingBilinear2D(scale_factor=2)
self.conv=ConvalutionBlock(in_channels,out_channels)
def forward(self, x1, x2):
x1 = self.upsample(x1)
# 因为tensor是ncwh的 我们需要在c维度上concat 所以axis是1
x=paddle.concat([x2, x1], axis=1)
return self.conv(x)
# 第一个stage的RSU或者说是作者论文中说的7层的RSU
class RSU1(nn.Layer):
def __init__(self,in_channels,mid_channels,out_channels):
super(RSU1, self).__init__()
self.pool=nn.MaxPool2D(2)
self.first_layer=ConvalutionBlock(in_channels,out_channels)
self.en_1=ConvalutionBlock(out_channels,mid_channels)
self.en_2=DownBlock(mid_channels,mid_channels)
self.en_3=DownBlock(mid_channels,mid_channels)
self.en_4=DownBlock(mid_channels,mid_channels)
self.en_5=DownBlock(mid_channels,mid_channels)
self.en_6=DownBlock(mid_channels,mid_channels)
self.en_7=ConvalutionBlock(mid_channels,mid_channels,2)
self.de_6=ConvalutionBlock(mid_channels*2,mid_channels)
self.de_5=UpBlock(mid_channels*2,mid_channels)
self.de_4=UpBlock(mid_channels*2,mid_channels)
self.de_3=UpBlock(mid_channels*2,mid_channels)
self.de_2=UpBlock(mid_channels*2,mid_channels)
self.de_1=UpBlock(mid_channels*2,out_channels)
def forward(self,x):
first_layer=self.first_layer(x)
en_1=self.en_1(first_layer)
en_2=self.en_2(en_1)
en_3=self.en_3(en_2)
en_4=self.en_4(en_3)
en_5=self.en_5(en_4)
en_6=self.en_6(en_5)
en_7=self.en_7(en_6)
de_6=self.de_6(paddle.concat([en_6,en_7],axis=1))
de_5=self.de_5(de_6,en_5)
de_4=self.de_4(de_5,en_4)
de_3=self.de_3(de_4,en_3)
de_2=self.de_2(de_3,en_2)
de_1=self.de_1(de_2,en_1)
return first_layer+de_1
# 第二个stage的RSU或者说是作者论文中说的6层的RSU
class RSU2(nn.Layer):
def __init__(self,in_channels,mid_channels,out_channels):
super(RSU1, self).__init__()
self.pool=nn.MaxPool2D(2)
self.first_layer=ConvalutionBlock(in_channels,out_channels)
self.en_1 = ConvalutionBlock(out_channels, mid_channels)
self.en_2 = DownBlock(mid_channels, mid_channels)
self.en_3 = DownBlock(mid_channels, mid_channels)
self.en_4 = DownBlock(mid_channels, mid_channels)
self.en_5 = DownBlock(mid_channels, mid_channels)
self.en_6 = ConvalutionBlock(mid_channels, mid_channels, 2)
self.de_5 = ConvalutionBlock(mid_channels * 2, mid_channels)
self.de_4 = UpBlock(mid_channels * 2, mid_channels)
self.de_3 = UpBlock(mid_channels * 2, mid_channels)
self.de_2 = UpBlock(mid_channels * 2, mid_channels)
self.de_1 = UpBlock(mid_channels * 2, out_channels)
def forward(self,x):
first_layer=self.first_layer(x)
en_1=self.en_1(first_layer)
en_2=self.en_2(en_1)
en_3=self.en_3(en_2)
en_4=self.en_4(en_3)
en_5=self.en_5(en_4)
en_6=self.en_6(en_5)
de_5=self.de_5(paddle.concat([en_5,en_6],axis=1))
de_4=self.de_4(de_5,en_4)
de_3=self.de_3(de_4,en_3)
de_2=self.de_2(de_3,en_2)
de_1=self.de_1(de_2,en_1)
return first_layer+de_1
# 第三个stage的RSU或者说是作者论文中说的5层的RSU
class RSU3(nn.Layer):
def __init__(self,in_channels,mid_channels,out_channels):
super(RSU1, self).__init__()
self.pool=nn.MaxPool2D(2)
self.first_layer=ConvalutionBlock(in_channels,out_channels)
self.en_1 = ConvalutionBlock(out_channels, mid_channels)
self.en_2 = DownBlock(mid_channels, mid_channels)
self.en_3 = DownBlock(mid_channels, mid_channels)
self.en_4 = DownBlock(mid_channels, mid_channels)
self.en_5 = ConvalutionBlock(mid_channels, mid_channels, 2)
self.de_4 = ConvalutionBlock(mid_channels * 2, mid_channels)
self.de_3 = UpBlock(mid_channels * 2, mid_channels)
self.de_2 = UpBlock(mid_channels * 2, mid_channels)
self.de_1 = UpBlock(mid_channels * 2, out_channels)
def forward(self,x):
first_layer=self.first_layer(x)
en_1=self.en_1(first_layer)
en_2=self.en_2(en_1)
en_3=self.en_3(en_2)
en_4=self.en_4(en_3)
en_5=self.en_5(en_4)
de_4=self.de_4(paddle.concat([en_4,en_5],axis=1))
de_3=self.de_3(de_4,en_3)
de_2=self.de_2(de_3,en_2)
de_1=self.de_1(de_2,en_1)
return first_layer+de_1
# 第四个stage的RSU或者说是作者论文中说的4层的RSU
class RSU4(nn.Layer):
def __init__(self,in_channels,mid_channels,out_channels):
super(RSU1, self).__init__()
self.pool=nn.MaxPool2D(2)
self.first_layer=ConvalutionBlock(in_channels,out_channels)
self.en_1 = ConvalutionBlock(out_channels, mid_channels)
self.en_2 = DownBlock(mid_channels, mid_channels)
self.en_3 = DownBlock(mid_channels, mid_channels)
self.en_4 = ConvalutionBlock(mid_channels, mid_channels, 2)
self.de_3 = ConvalutionBlock(mid_channels * 2, mid_channels)
self.de_2 = UpBlock(mid_channels * 2, mid_channels)
self.de_1 = UpBlock(mid_channels * 2, out_channels)
def forward(self,x):
first_layer=self.first_layer(x)
en_1=self.en_1(first_layer)
en_2=self.en_2(en_1)
en_3=self.en_3(en_2)
en_4=self.en_4(en_3)
de_3=self.de_4(paddle.concat([en_3,en_4],axis=1))
de_2=self.de_2(de_3,en_2)
de_1=self.de_1(de_2,en_1)
return first_layer+de_1
# 最后一个stage 不改变尺寸 只做特征提取
class RSU4F(nn.Layer):
def __init__(self,in_channels,mid_channels,out_channels):
super(RSU4F, self).__init__()
self.first_layer=ConvalutionBlock(in_channels,out_channels)
self.en_1 = ConvalutionBlock(out_channels, mid_channels)
self.en_2 = ConvalutionBlock(mid_channels, mid_channels,2)
self.en_3 = ConvalutionBlock(mid_channels, mid_channels,4)
self.en_4 = ConvalutionBlock(mid_channels, mid_channels,8)
self.de_3 = ConvalutionBlock(mid_channels * 2, mid_channels,4)
self.de_2 = ConvalutionBlock(mid_channels * 2, mid_channels,2)
self.de_1 = ConvalutionBlock(mid_channels * 2, mid_channels)
def forward(self,x):
first_layer=self.first_layer(x)
en_1=self.en_1(first_layer)
en_2=self.en_2(en_1)
en_3=self.en_3(en_2)
en_4=self.en_4(en_3)
de_3=self.de_3(paddle.concat([en_3,en_4],axis=1))
de_2=self.de_2(paddle.concat([en_2,de_3],axis=1))
de_1=self.de_1(paddle.concat([en_1,de_2],axis=1))
return first_layer+de_1
class U2Net(nn.Layer):
def __init__(self,in_channel=3,out_channel=1):
super(U2Net, self).__init__()
self.in_channel=in_channel
self.out_channel=out_channel
self.en_1 = nn.Sequential(
RSU1(self.in_channel,32,64),
nn.MaxPool2D(2)
)
self.en_2 = nn.Sequential(
RSU2(64, 32, 128),
nn.MaxPool2D(2)
)
self.en_3 = nn.Sequential(
RSU3(128, 64, 256),
nn.MaxPool2D(2)
)
self.en_4 = nn.Sequential(
RSU4(256,128, 512),
nn.MaxPool2D(2)
)
self.en_5 = nn.Sequential(
RSU4F(512, 256, 512),
nn.MaxPool2D(2)
)
self.en_6=RSU4F(512, 256, 512)
self.de_5=RSU4F(1024, 256, 512)
self.de_4=RSU4(1024, 128, 256)
self.de_3=RSU3(512, 64, 128)
self.de_2=RSU2(256, 32, 64)
self.de_1=RSU1(128, 32, 64)
self.sup1 = nn.Conv2D(64,self.out_channel)
self.sup2 = nn.Conv2D(64,self.out_channel)
self.sup3 = nn.Conv2D(128,self.out_channel)
self.sup4 = nn.Conv2D(512,self.out_channel)
self.sup5 = nn.Conv2D(512,self.out_channel)
self.sup6 = nn.Conv2D(512,self.out_channel)
self.sup0=nn.Conv2D(6*self.out_channel,1,3,1,1)
def forward(self,x):
en_1=self.en_1(x)
en_2=self.en_2(en_1)
en_3=self.en_3(en_2)
en_4=self.en_4(en_3)
en_5=self.en_5(en_4)
en_6=self.en_6(en_5)
de_5=self.de_5(paddle.concat([en_5,en_6],axis=1))
de_4=self.de_4(UpBlock(de_5,en_4))
de_3=self.de_3(UpBlock(de_4,en_3))
de_2=self.de_2(UpBlock(de_3,en_2))
de_1=self.de_1(UpBlock(de_2,en_1))
sup1 = self.sup1(de_1)
sup2 = self.sup2(de_2)
sup3 = self.sup3(de_3)
sup4 = self.sup4(de_4)
sup5 = self.sup5(de_5)
sup6 = self.sup6(en_6)
sup0=self.sup0(paddle.concat([sup1,sup2,sup3,sup4,sup5,sup6]))
return sup0
if __name__ == '__main__':
# x=paddle.randn(shape=[2,3,256,256])
unet=U2Net()
# print(net(x).shape)
paddle.summary(unet, (1,3,256,256))