U-net经典结构学习(U-Net: Convolutional Networks for Biomedical Image Segmentation)
一、介绍
在本论文中,文章提出了一种网络和训练策略,该策略依赖于数据增强的强大使用,以更有效地使用可用的注释样本。该体系结构包括捕获上下文的收缩路径和支持精确定位的对称扩展路径。论文表明,这样的网络可以从很少的图像中进行端到端训练,并且在ISBI挑战中优于先前的最佳方法(滑动窗口卷积网络),以分割电子显微镜堆栈中的神经元结构。
因为专业是声源分离方面,所以读这篇论文只是了解到U-net的模型结构和代码以及U-net相对于先前卷积网络的有效性,为之后阅读其他论文做准备(后面许多论文是基于U-net的)。
二、U-net优势
1、论文修改和扩展了“全卷积网络”,使其适用于很少的训练图像,并产生更精确的分割。本论文架构中的一个重要修改是,在上采样部分,有大量的特征通道,这允许网络将上下文信息传播到更高分辨率的层。因此,扩张路径或多或少与收缩路径对称,并产生U形结构。
2、网络没有任何完全连接的层,只使用每个卷积的有效部分,即分割映射只包含像素,在输入图像中可以获得完整的上下文。
3、该策略允许通过重叠贴图策略对任意大的图像进行无缝分割。为了预测图像边界区域的像素,通过镜像输入图像来推断缺失的上下文。
三、U-net结构
U-net整体结构分为收缩路径和扩张路径两部分,因此整体呈现出一种U型结构。
U-net结构网络是一个经典的全卷积网络(即网络中没有全连接操作)。
网络的输入是一张 572 × 572 572 的边缘经过镜像操作的图片(input image tile),网络的左侧(红色虚线)是由卷积和Max Pooling构成的一系列降采样操作,论文中将这一部分叫做压缩路径(contracting path)。压缩路径由4个block组成,每个block使用了3个有效卷积(kernel size = 3×3)和1个Max Pooling降采样,每次降采样之后Feature Map的个数乘2,因此有了图中所示的Feature Map尺寸变化。最终得到了尺寸为32 × 32 的Feature Map。
网络的右侧部分(绿色虚线)在论文中叫做扩展路径(expansive path)。同样由4个block组成,每个block开始之前通过反卷积将Feature Map的尺寸乘2,同时将其个数减半(最后一层略有不同),然后和左侧对称的压缩路径的Feature Map合并,由于左侧压缩路径和右侧扩展路径的Feature Map的尺寸不一样,U-Net是通过将压缩路径的Feature Map裁剪到和扩展路径相同尺寸的Feature Map进行归一化的(即图1中左侧虚线部分)。扩展路径的卷积操作依旧使用的是有效卷积操作,最终得到的Feature Map的尺寸是338 × 338 。
四、U-net模型结构代码复现
import torch
from torch import nn
import torchvision.transforms.functional
# 卷积层
class DoubleConv(nn.Module):
"""
卷积过程
"""
def __init__(self,input_channel,output_channel):
super().__init__()
# 第一个卷积层
self.first = nn.Conv2d(input_channel,output_channel,kernel_size=3)
# Relu 激活函数
self.relu = nn.ReLU()
# 第二个卷积层
self.second = nn.Conv2d(output_channel,output_channel,kernel_size=3)
def forward(self,x):
x = self.relu(self.first(x))
# 打开print 即可查看运行时各层的参数
# print(x.shape)
x = self.relu(self.second(x))
# print(x.shape)
return x
# 下采样层
class DownSample(nn.Module):
"""
下采样过程,通道数增加,特征数减半
"""
def __init__(self):
super().__init__()
self.pooling = nn.MaxPool2d(2)
def forward(self,x):
return self.pooling(x)
class UpSample(nn.Module):
"""
上采样过程,通道数减少,特征数加倍
"""
def __init__(self,in_channels,out_channels):
super().__init__()
self.up = nn.ConvTranspose2d(in_channels,out_channels,kernel_size=2,stride=2)
def forward(self,x):
x = self.up(x)
return x
class CropAndConcat(nn.Module):
"""
在扩展路径的每一步,来自收缩路径的相应特征图与当前特征图连接。
"""
def __init__(self):
super().__init__()
def forward(self,x,contracting_x):
contracting_x =torchvision.transforms.functional.center_crop(contracting_x,(x.shape[2],x.shape[3]))
x = torch.cat((x,contracting_x),dim = 1)
return x
class Unet (nn.Module):
def __init__(self,in_channels,out_channels):
super().__init__()
# 下采样卷积
self.down_conv = nn.ModuleList([DoubleConv(i,ou)
for i,ou in [(in_channels,64),(64,128),(128,256),(256,512)]])
# 下采样池化层
self.down_sample = nn.ModuleList(DownSample() for _ in range(4))
# 中间层
self.middle_conv = DoubleConv(512,1024)
# 上采样卷积
self.up_conv = nn.ModuleList([DoubleConv(i,ou)
for i,ou in [(1024,512),(512,256),(256,128),(128,64)]])
self.up_sample = nn.ModuleList([UpSample(i,ou)
for i,ou in [(1024,512),(512,256),(256,128),(128,64)]])
self.concat = [CropAndConcat() for _ in range(4)]
self.final_conv = nn.Conv2d(64,out_channels,kernel_size=1)
def forward(self,x):
# 记录收缩路径的相应特征图
pass_though= []
for i in range(len(self.down_conv)):
print("下采样第{}层".format(i))
x = self.down_conv[i](x)
pass_though.append(x)
x = self.down_sample[i](x)
print("中间层")
# print(x.shape)
x = self.middle_conv(x)
print(x.shape)
for i in range(len(self.up_conv)):
#print("上采样第{}层".format(i))
x = self.up_sample[i](x)
# print(x.shape)
x = self.concat[i](x,pass_though.pop())
#print(x.shape)
x = self.up_conv[i](x)
x = self.final_conv(x)
#print(x.shape)
return x
u = Unet(1,2)
# for para in u.parameters():
# # 打印Conv2d层的参数
# print(para.shape)
x = torch.ones([1,1,572,572])
print(u(x))