论文正文
概述
本文主要提出了U-Net这个网络结构,在少量数据集的情况下配合一定的数据增广,可以端到端训练,进行图像的语义分割。
以往基于深度学习的分割有两种方式:
方式1
:使用滑动窗口的策略得到一个个的patch
,patch就是当前像素以及它对应的上下文,将patch喂入模型,得到像素的类别。首先它很慢,并且不好把握准确率与patch大小的权衡。
方式2
:全卷积网络做分割。相关介绍,链接
U-Net就是在方式2的基础上得到的。
细节
网络结构
这个结构呈U字状,所以称为U-Net,这张图也将结构展现的很清晰了。输入图片经过四次的两个卷积快+一次下采样,然后再是四次的两个卷积块+一次上采样+skip-connnection的过程,最后三个1x1的卷积得到结果。
但是其中有一些细节如下:
- 卷积:这里的卷积是valid卷积,也就是没有padding的填充,所以会有尺寸的减小。或者按照以下的情况理解。传统的卷积模型有三种一种是full,也就是卷积核碰到图片就做卷积;另一种是same,也就是卷积核的中心碰到图片就做卷积,same卷积是最常见的卷积,在这种情况下,设定好kernel size,stride,padding就可以使得卷积前后图像尺寸不变;最后一种就是valid,卷积核图片重叠才做卷积。以下是找的几张图
- 下采样,这里的下采样用的是最大池化
- 上采样,这里的上采样用的是转置卷积,而不是平常使用的线性插值法,同时会把channel数减半。
- skip-connection,和FPN中一样,上采样结束之后,会与左边的特征进行concat操作,融合深层的语义级别特征和浅层的细粒度特征,但是,这里需要做一步crop操作。因为我们看图就能发现,两者的尺寸是不相同的,而不相同的原因就是因为前面的valid卷积。
- 最后,使用三次1x1的卷积将通道数慢慢降下来,得到最后两个通道的结果。是两个中通道的原因在于,做的是2分类,前景或者背景。但是看之前的DBNet得到的是一层的分割图,我的理解是:DBNet只是为了做文本的分割,所以每个像素点只需要计算属于文本实例的概率就好了。
Overlap-tile strategy
这个策略的理解我看网上有很多不同的解释,那我的理解是,黄色部分是原图,我们想要预测上面每个像素的类别(做分割),对于其中的每个像素,都需要一些上下文信息才行,但是其中会有一些问题:
首先
,边沿缺失上下文信息,怎么办呢?使用一些填充,将黄色部分扩展成蓝色。最简单的填充就是0填充,但是这种填充其实没有增加任何的上下文信息,那么作者选用的是镜像填充(mode=‘reflect’),即使用当前黄色框边沿的镜像作为填充。
其次
,使用了上面这种填充之后,临近的像素的填充肯定有很多重叠,怎么办呢?使用valid卷积,这样每次到了下一层,尺寸就会变小一点,慢慢的到后面这个填充就抵消掉了,得到最后的结果。
注
:Overlap-tile策略可搭配patch(图像分块)一起使用。当内存资源有限从而无法对整张大图进行预测时,可以对图像先进行镜像padding,然后按序将padding后的图像分割成固定大小的patch。这样,能够实现对任意大的图像进行无缝分割,同时每个图像块也获得了相应的上下文信息。另外,在数据量较少的情况下,每张图像都被分割成多个patch,相当于起到了扩充数据量的作用。更重要的是,这种策略不需要对原图进行缩放,每个位置的像素值与原图保持一致,不会因为缩放而带来误差。
训练
首先原图经过一次镜像填充,尺寸变大了,然后通过valid卷积,尺寸不断变小,最后得到最终的feature-map,它的尺寸就和原图一样了,这样的话,就可以训练了,我们逐像素做二元交叉熵损失就可以了。
简答实现
import paddle
import paddle.nn.functional as F
import paddle.nn as nn
# 两次卷积操作
# 卷积计算公式:
# 输出大小 = (输入大小 − Filter + 2Padding )/Stride+1
class VGGBlock(nn.Layer):
def __init__(self,in_channels,out_channels):
super(VGGBlock, self).__init__()
self.layer=nn.Sequential(
nn.Conv2D(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2D(out_channels),
nn.LeakyReLU(),
nn.Conv2D(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2D(out_channels),
nn.LeakyReLU()
)
def forward(self,x):
return self.layer(x)
# 将decoder当前层上采样并且和encoder当前层做concat
# 反卷积(转置卷积)计算公式:
# 输出大小 = (输入大小 − 1) * Stride + Filter - 2 * Padding
# 当前这种设置使得输入输出尺寸相同
class Up(nn.Layer):
def __init__(self,in_channels,out_channels):
super(Up, self).__init__()
self.layer=nn.Sequential(
nn.Conv2DTranspose(in_channels, out_channels, 4, 2, 1)
)
def forward(self,x1,x2):
x1=self.layer(x1)
# 因为tensor是ncwh的 我们需要在c维度上concat 所以axis是1
return paddle.concat([x2,x1],axis=1)
class UNet(nn.Layer):
def __init__(self,num_classes=2):
super(UNet, self).__init__()
filters=[64, 128, 256, 512, 1024]
self.pool= nn.MaxPool2D(2)
## -------------encoder-------------
self.encoder_1=VGGBlock(3,filters[0])
self.encoder_2=VGGBlock(filters[0],filters[1])
self.encoder_3=VGGBlock(filters[1],filters[2])
self.encoder_4=VGGBlock(filters[2],filters[3])
self.encoder_5=VGGBlock(filters[3],filters[4])
## -------------decoder-------------
self.up_4=Up(filters[4],filters[3])
self.up_3=Up(filters[3],filters[2])
self.up_2=Up(filters[2],filters[1])
self.up_1=Up(filters[1],filters[0])
self.decoder_4 = VGGBlock(filters[4],filters[3])
self.decoder_3 = VGGBlock(filters[3],filters[2])
self.decoder_2 = VGGBlock(filters[2],filters[1])
self.decoder_1 = VGGBlock(filters[1],filters[0])
self.final = nn.Sequential(
nn.Conv2D(filters[0],num_classes,3,1,1),
)
def forward(self,x):
## -------------encoder-------------
encoder_1=self.encoder_1(x)
encoder_2=self.encoder_2(self.pool(encoder_1))
encoder_3=self.encoder_3(self.pool(encoder_2))
encoder_4=self.encoder_4(self.pool(encoder_3))
encoder_5=self.encoder_5(self.pool(encoder_4))
## -------------decoder-------------
decoder_4=self.up_4(encoder_5,encoder_4)
decoder_4=self.decoder_4(decoder_4)
decoder_3 = self.up_3(decoder_4,encoder_3)
decoder_3=self.decoder_3(decoder_3)
decoder_2 = self.up_2(decoder_3,encoder_2)
decoder_2=self.decoder_2(decoder_2)
decoder_1 = self.up_1(decoder_2,encoder_1)
decoder_1=self.decoder_1(decoder_1)
output = self.final(decoder_1)
return output
if __name__ == '__main__':
# x=paddle.randn(shape=[2,3,256,256])
unet=UNet()
# print(net(x).shape)
paddle.summary(unet, (1,3,256,256))