Unet网络结构:
(1)UNet网络为一个U形结构。采用全卷积神经网络,没有全连接操作
(2)左边encoder编码器部分为特征提取网络:使用conv和pooling进行下采样
(3)右边decoder解码器部分为特征融合网络:右侧上采样产生的特征图与左侧下采样的特征图在channel维度上进行concatenate拼接操作。(图片的维度为:B C H W,分别为batchsize,channel, height, width)
上采样的目的:pooling池化层使得图片宽高减半,会丢失图像信息,降低分辨率。上采样可以提高图片分辨率,并且保留高级抽象特征,然后再与左边低级表层特征高分辨率图片拼接。
上采样方法:使用转置卷积nn.ConvTranspose2d()代替简单的插值上采样方法,既能实现同样的效果,也能加深网络。
(4)最后再经过两次3*3卷积操作,再用1*1的卷积核,输出channel维度为需要分割的类别数num_classes,生成维度为(B,num_classes,H,W)的特征图。
code:
import torch
import torch.nn as nn
import torch.nn.functional as F
def X2conv(in_channel,out_channel):
"""连续两个3*3卷积"""
return nn.Sequential(
nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channel),
nn.ReLU(),
nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channel),
nn.ReLU())
class DownsampleLayer(nn.Module):
"""
下采样层
"""
def __init__(self,in_channel,out_channel):
super(DownsampleLayer, self).__init__()
self.x2conv=X2conv(in_channel,out_channel)
self.pool=nn.MaxPool2d(kernel_size=2,ceil_mode=True)
def forward(self,x):
"""
:param x:上一层pool后的特征
:return: out_1转入右侧(待拼接),out_1输入到下一层,
"""
out_1=self.x2conv(x)
out=self.pool(out_1)
return out_1,out
class UpSampleLayer(nn.Module):
"""
上采样层
"""
def __init__(self,in_channel,out_channel):
super(UpSampleLayer, self).__init__()
self.x2conv = X2conv(in_channel, out_channel)
self.upsample=nn.ConvTranspose2d\
(in_channels=out_channel,out_channels=out_channel//2,kernel_size=3,stride=2,padding=1)
def forward(self,x,out):
'''
:param x: decoder中:输入层特征,经过x2conv与上采样upsample,然后拼接
:param out:左侧encoder层中特征(与右侧上采样层进行cat)
:return:
'''
x=self.x2conv(x)
x=self.upsample(x)
# x.shape中H W 应与 out.shape中的H W相同
if (x.size(2) != out.size(2)) or (x.size(3) != out.size(3)):
# 将右侧特征H W大小插值变为左侧特征H W大小
x = F.interpolate(x, size=(out.size(2), out.size(3)),
mode="bilinear", align_corners=True)
# Concatenate(在channel维度)
cat_out = torch.cat([x, out], dim=1)
return cat_out
class UNet(nn.Module):
"""
UNet模型,num_classes为分割类别数
"""
def __init__(self,num_classes):
super(UNet, self).__init__()
#下采样
self.d1=DownsampleLayer(3,64) #3-64
self.d2=DownsampleLayer(64,128)#64-128
self.d3=DownsampleLayer(128,256)#128-256
self.d4=DownsampleLayer(256,512)#256-512
#上采样
self.u1=UpSampleLayer(512,1024)#512-1024-512
self.u2=UpSampleLayer(1024,512)#1024-512-256
self.u3=UpSampleLayer(512,256)#512-256-128
self.u4=UpSampleLayer(256,128)#256-128-64
#输出:经过一个二层3*3卷积 + 1个1*1卷积
self.x2conv=X2conv(128,64)
self.final_conv=nn.Conv2d(64,num_classes,kernel_size=1) # 最后一个卷积层的输出通道数为分割的类别数
self._initialize_weights()
def _initialize_weights(self):
for module in self.modules():
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.fill_(1)
module.bias.data.zero_()
def forward(self,x):
# 下采样层
out_1,out1=self.d1(x)
out_2,out2=self.d2(out1)
out_3,out3=self.d3(out2)
out_4,out4=self.d4(out3)
# 上采样层 拼接
out5=self.u1(out4,out_4)
out6=self.u2(out5,out_3)
out7=self.u3(out6,out_2)
out8=self.u4(out7,out_1)
# 最后的三层卷积
out=self.x2conv(out8)
out=self.final_conv(out)
return out
if __name__ == "__main__":
img = torch.randn((2, 3, 360, 480)) # 正态分布初始化
model = UNet(num_classes=16)
output = model(img)
print(output.shape)