UNet是一种用于图像分割的深度学习网络结构,由Olaf Ronneberger、Philipp Fischer和Thomas Brox于2015年提出。它在医学图像分割领域得到广泛应用,尤其在生物医学图像分析中。UNet的主要作用是将输入图像分割为像素级别的预测结果,即对每个像素进行分类,判断其属于哪个类别。这使得UNet在许多任务中非常有用,如细胞分割、肿瘤检测、器官分割等。
网络结构分析
UNet的结构可以分为两个部分:编码器(Encoder)和解码器(Decoder)。编码器负责逐步降低特征图的空间分辨率,提取图像的高级语义信息。解码器则通过上采样操作逐步恢复特征图的空间分辨率,并将低级特征与高级特征进行融合,以生成精细的分割结果。
首先解释一下上图: 蓝/白色框表示feature map(特征图); 蓝色箭头表示3*3卷积,主要用于特征提取; 灰色箭头表示skip-connection(跳跃连接,通常用于残差网络中),在这里是用于用于特征融合,其中copy and crop中的copy就是concatenate而crop(剪切图像)是为了让两者的长宽一致; 红色箭头表示池化 pooling,用于降低维度; 绿色箭头表示上采样 upsample,用于恢复维度; 青色箭头表示1*1卷积,用于输出结果; 左侧为特征提取部分,红色箭头为通过一次max pooling 得到一个新的尺度; 右侧为上采样部分,在这里每上采样一次就相当于和特征提取部分对应的通道数相同尺度融合,在这里,融合之前需要将其crop,左侧图中可以隐约看到蓝色的虚线框,那就是crop的过程
。
手撸代码:
import torch
from torch import nn
from torch.nn import functional as F
class Conv_Block(nn.Module):
def __init__(self, in_channel, out_channel):
super(Conv_Block, self).__init__()方法
self.layer = nn.Sequential(
nn.Conv2d(in_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
nn.BatchNorm2d(out_channel),
nn.Dropout2d(0.3),
nn.LeakyReLU(),
nn.Conv2d(out_channel,out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
nn.BatchNorm2d(out_channel),
nn.Dropout2d(0.3),
nn.LeakyReLU()
)
def forward(self,x):
return self.layer(x)
#下采样
class DownSample(nn.Module):
def __init__(self, channel):#下采样也是有输入通道的
super(DownSample, self).__init__()
self.layer = nn.Sequential(
nn.Conv2d(channel, channel, 3, 2, 1, padding_mode='reflect', bias=False),
nn.BatchNorm2d(channel),
nn.LeakyReLU()
)
def forward(self, x):
return self.layer(x)
#上采样
class UpSample(nn.Module):
def __init__(self, channel):
super(UpSample, self).__init__()
self.layer = nn.Conv2d(channel, channel//2, 1, 1, )
def forward(self, x, feature_map):
up = F.interpolate(x, scale_factor=2, mode='nearest')
out = self.layer(up)
return torch.cat((out, feature_map), dim=1)
#网络结构
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
self.c1 = Conv_Block(3, 64)#卷积,输入是3个通道,输出是64个通道
self.d1 = DownSample(64)#下采样
self.c2 = Conv_Block(64, 128)
self.d2 = DownSample(128)
self.c3 = Conv_Block(128, 256)
self.d3 = DownSample(256)
self.c4 = Conv_Block(256, 512)
self.d4 = DownSample(512)
self.c5 = Conv_Block(512, 1024)
self.u1 = UpSample(1024)
self.c6 = Conv_Block(1024, 512)
self.u2 = UpSample(512)
self.c7 = Conv_Block(512, 256)
self.u3 = UpSample(256)
self.c8 = Conv_Block(256, 128)
self.u4 = UpSample(128)
self.c9 = Conv_Block(128, 64)
self.out = nn.Conv2d(64, 3, 3, 1, 1)
self.Th = nn.Sigmoid()
def forward(self, x):
#下采样
R1 = self.c1(x)
R2 = self.c2(self.d1(R1))
R3 = self.c3(self.d2(R2))
R4 = self.c4(self.d3(R3))
R5 = self.c5(self.d4(R4))
#上采样
O1 = self.c6(self.u1(R5, R4))
O2 = self.c7(self.u2(O1, R3))
O3 = self.c8(self.u3(O2, R2))
O4 = self.c9(self.u4(O3, R1))
return self.Th(self.out(O4))
if __name__ == '__main__':
x = torch.randn(2, 3, 256, 256)#2个批次,3个通道,256*256的图片
net = UNet()
print(net(x).shape)