Unet网络架构
每个stage包括了编码和解码两个部分
代码中的编码和解码部分是分开的
即
Stage1 = conv1 + up_concat1
Stage2 = conv2 + up_concat2
Stage3 = conv3 + up_concat3
Stage4 = conv4 + up_concat4
同时这些stage并不包括池化层和反卷积层
代码debug可以看到模型的架构图,网络模型图中可见共有编码部分的4个3×3卷积层conv1234,1个最底部的centre层,解码部分的4个3×3卷积层up_concat1234,4个maxpooling池化层,1个final层,一共14个层(4+1+4+4+1)
Conv1234层 分别包括了两个Conv2d的基本卷积层
Maxpooling即基本的池化层
Center 也包括两个最基本的卷积层(同样由unetConv2构成)
Up_concat包括了两个基本的卷积层(一个conv)和一个逆卷积层up
final层代指最后一层的1×1卷积层,仅包括一个卷积层
附pytorch实现代码
class unetConv2(nn.Module):
    """Double 3x3 convolution block used throughout the U-Net.

    Applies two consecutive (Conv2d 3x3, valid padding -> optional
    BatchNorm -> ReLU) stages, mapping ``in_size`` channels to
    ``out_size``. Each valid (padding=0) convolution shrinks H and W
    by 2 pixels, so the output spatial size is the input size minus 4.
    """

    def __init__(self, in_size, out_size, is_batchnorm):
        super(unetConv2, self).__init__()
        if is_batchnorm:
            # Two Conv -> BatchNorm -> ReLU stages.
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=0),
                nn.BatchNorm2d(out_size),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=0),
                nn.BatchNorm2d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            # Same two stages without batch normalisation.
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=0),
                nn.ReLU(inplace=True),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=0),
                nn.ReLU(inplace=True),
            )

    def forward(self, inputs):
        """Run both conv stages in sequence."""
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs
class unetUp(nn.Module):
    """U-Net decoder stage: upsample, crop-and-concat skip, double conv.

    Args:
        in_size: channels of the deeper feature map (equals the channel
            count of the concatenated tensor fed to the double conv).
        out_size: output channels of this stage.
        is_deconv: True -> learned ConvTranspose2d 2x upsampling
            (the "up-conv 2x2" of the U-Net paper); False -> fixed
            bilinear upsampling.
    """

    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        # Two plain 3x3 convs (no batchnorm) applied after concatenation.
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            # NOTE(review): the bilinear path keeps in_size channels, so the
            # concat width differs from the deconv path; the network below
            # always uses is_deconv=True.
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        """inputs1: encoder skip feature; inputs2: deeper decoder feature."""
        outputs2 = self.up(inputs2)
        # offset is negative when the skip feature is larger than the
        # upsampled map (valid convolutions shrink the encoder outputs),
        # so F.pad with negative values center-crops inputs1.
        offset = outputs2.size()[2] - inputs1.size()[2]
        half = offset // 2
        # Split an odd offset asymmetrically (half, offset - half) so the
        # cropped skip feature matches outputs2 exactly. The original
        # symmetric 2*[offset//2, offset//2] padding over/under-cropped by
        # one pixel for odd size differences, making torch.cat fail; even
        # offsets behave exactly as before.
        padding = [half, offset - half, half, offset - half]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))
class unet(nn.Module):
    """U-Net for semantic segmentation (valid-padding variant).

    Four encoder stages (conv1..conv4, each followed by 2x2 max pooling),
    a bottleneck ``center`` stage, four decoder stages
    (up_concat4..up_concat1) and a final 1x1 convolution producing
    ``n_classes`` score channels.

    Args:
        feature_scale: divisor applied to the base channel widths
            [64, 128, 256, 512, 1024].
        n_classes: number of output channels of the final 1x1 conv.
        is_deconv: use transposed convolutions (not plain upsampling)
            in the decoder stages.
        in_channels: channels of the input image.
        is_batchnorm: add BatchNorm inside encoder/center conv blocks.
    """

    def __init__(self, feature_scale=4, n_classes=21, is_deconv=True,
                 in_channels=3, is_batchnorm=True):
        super(unet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        # Channel widths of the five resolution levels, shrunk by feature_scale.
        filters = [64, 128, 256, 512, 1024]
        filters = [int(x / self.feature_scale) for x in filters]
        print("filters", filters)
        # Encoder: double-conv blocks separated by 2x2 max pooling.
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        # Bottleneck double-conv block at the lowest resolution.
        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)
        # Decoder: upsample + crop/concat the matching encoder skip + double conv.
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)
        # Final 1x1 conv mapping features to per-class score maps.
        self.final = nn.Conv2d(filters[0], n_classes, kernel_size=1)

    def forward(self, inputs):
        """Encoder -> center -> decoder (with skips) -> final 1x1 conv."""
        conv1 = self.conv1(inputs)
        maxpool1 = self.maxpool1(conv1)
        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)
        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)
        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)
        center = self.center(maxpool4)
        # Each decoder stage consumes the corresponding encoder output as skip.
        up4 = self.up_concat4(conv4, center)
        up3 = self.up_concat3(conv3, up4)
        up2 = self.up_concat2(conv2, up3)
        up1 = self.up_concat1(conv1, up2)
        final = self.final(up1)
        return final
import torchvision

# Fall back to CPU when CUDA is unavailable so the script still runs on
# machines without a GPU (the original hard-coded 'cuda:0').
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def count_param(model):
    """Return the total number of parameters in ``model``.

    Uses Tensor.numel(), which counts elements directly instead of the
    original ``param.view(-1).size()[0]`` reshape-then-measure idiom.
    """
    return sum(param.numel() for param in model.parameters())
if __name__ == "__main__":
    # feature_scale=1 keeps the full [64..1024] channel widths of the paper.
    model = unet(feature_scale=1)
    # `summary` comes from the third-party torchsummary package; the original
    # script called it without importing it anywhere (guaranteed NameError).
    # Guard the import so the parameter count still prints when it is missing.
    try:
        from torchsummary import summary
        print(summary(model, (3, 572, 572), device="cpu"))
    except ImportError:
        print("torchsummary not installed; skipping layer summary")
    print("count_param : ", count_param(model))
    # model = torchvision.models.vgg16()
    # model.to(device)
    # print(summary(model, (3, 224, 224)))
参数量的计算见上一篇博客