U-net网络架构个人理解

Unet网络架构
Unet网络架构
每个stage包括了编码和解码两个部分
代码中的编码和解码部分是分开的

Stage1=con1+up_concat1
Stage2=con2+up_concat2
Stage3=con3+up_concat3
Stage4=con4+up_concat4
同时这些stage并不包括池化层和反卷积层

在这里插入图片描述
代码debug可以看到模型的架构图,网络模型图中可见共有编码部分的4个3×3卷积层conv1234,1个最底部的centre层,解码部分的4个3×3卷积层up_concat1234,4个maxpooling池化层,1个final层,一共14个层(4+1+4+4+1)

Conv1234层 分别包括了两个Con2d的基本卷积层
在这里插入图片描述
Maxpooling即基本的池化层
Center 也包括两个最基本的池化层
Up_concat包括了两个基本的卷积层(一个conv)和一个逆卷积层up
在这里插入图片描述
final层代指最后一层的1×1卷积层,仅包括一个卷积层

附pytorch实现代码

class unetConv2(nn.Module):                #UNET的卷积层00
    def __init__(self,in_size,out_size,is_batchnorm):
        super(unetConv2,self).__init__()

        if is_batchnorm:                                 #两个有batchnorm的con3*3卷积层
            self.conv1=nn.Sequential(
                nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
                nn.BatchNorm2d(out_size),
                nn.ReLU(inplace=True),
            )
            self.conv2=nn.Sequential(
                nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
                nn.BatchNorm2d(out_size),
                nn.ReLU(inplace=True),
            )
        else:
            self.conv1=nn.Sequential(
                nn.Conv2d(in_size,out_size,kernel_size=3,stride=1,padding=0),
                nn.ReLU(inplace=True),
            )
            self.conv2=nn.Sequential(
                nn.Conv2d(out_size,out_size,kernel_size=3,stride=1,padding=0),
                nn.ReLU(inplace=True)
            )
    def forward(self, inputs):
        outputs=self.conv1(inputs)
        outputs=self.conv2(outputs)

        return outputs

class unetUp(nn.Module):
    def __init__(self,in_size,out_size,is_deconv):
        super(unetUp,self).__init__()
        self.conv=unetConv2(in_size,out_size,False)  #先进行两个unet网络自定义的等大小的unetConv2_3*3卷积操作(此处不带batchnorm过程了
        if is_deconv:                       #解码卷积  二选一 网络中全部选择了逆卷积 而不是简单的上采样
            self.up=nn.ConvTranspose2d(in_size,out_size,kernel_size=2,stride=2)  #逆卷积
        else:
            self.up=nn.UpsamplingBilinear2d(scale_factor=2)  #上采样  up-conv2*2

    def forward(self, inputs1,inputs2):
        outputs2=self.up(inputs2)
        offset=outputs2.size()[2]-inputs1.size()[2]
        padding=2*[offset//2,offset//2]
        outputs1=F.pad(inputs1,padding)     #padding is negative, size become smaller

        return self.conv(torch.cat([outputs1,outputs2],1))

class unet(nn.Module):
    # 初始化这个类,定义本类的所有属性(函数)分别都是什么,但那时并没有运行步骤,仅仅进行该类的声明
    def __init__(self,feature_scale=4,n_classes=21,is_deconv=True,in_channels=3,is_batchnorm=True):
        super(unet,self).__init__()
        self.is_deconv=is_deconv
        self.in_channels=in_channels
        self.is_batchnorm=is_batchnorm
        self.feature_scale=feature_scale

        filters=[64,128,256,512,1024]
        filters=[int(x/self.feature_scale) for x in filters]
        print("filters",filters)
        self.conv1=unetConv2(self.in_channels,filters[0],self.is_batchnorm)    #调用了unet网络自己定义的卷积类(unetConv2)
        self.maxpool1=nn.MaxPool2d(kernel_size=2)

        self.conv2=unetConv2(filters[0],filters[1],self.is_batchnorm)
        self.maxpool2=nn.MaxPool2d(kernel_size=2)

        self.conv3=unetConv2(filters[1],filters[2],self.is_batchnorm)
        self.maxpool3=nn.MaxPool2d(kernel_size=2)

        self.conv4=unetConv2(filters[2],filters[3],self.is_batchnorm)
        self.maxpool4=nn.MaxPool2d(kernel_size=2)

        self.center=unetConv2(filters[3],filters[4],self.is_batchnorm)

        #umsampling
        self.up_concat4=unetUp(filters[4],filters[3],self.is_deconv)      #调用了unet网络自己定义的unetup解码网络类(其中逆卷积选择的不是简单的上采样)
        self.up_concat3=unetUp(filters[3],filters[2],self.is_deconv)
        self.up_concat2=unetUp(filters[2],filters[1],self.is_deconv)
        self.up_concat1=unetUp(filters[1],filters[0],self.is_deconv)

        #final conv (without and concat)
        self.final=nn.Conv2d(filters[0],n_classes,kernel_size=1)

    #具体的传播函数,有调用初始化过程中的函数,开始进行网络的传播过程
    def forward(self, inputs):    #同时在此处进行其他类别的forward调用
        conv1=self.conv1(inputs)
        maxpool1=self.maxpool1(conv1)

        conv2=self.conv2(maxpool1)
        maxpool2=self.maxpool2(conv2)

        conv3=self.conv3(maxpool2)
        maxpool3=self.maxpool3(conv3)

        conv4=self.conv4(maxpool3)
        maxpool4=self.maxpool4(conv4)

        center=self.center(maxpool4)
        up4=self.up_concat4(conv4,center)
        up3=self.up_concat3(conv3,up4)
        up2=self.up_concat2(conv2,up3)
        up1=self.up_concat1(conv1,up2)

        final=self.final(up1)

        return final


import torchvision

device = torch.device('cuda:0')


def count_param(model):
    param_count = 0
    for param in model.parameters():
        param_count += param.view(-1).size()[0]
    return param_count

if __name__=="__main__":
    model=unet(feature_scale=1)
    print(summary(model,(3, 572, 572),device="cpu"))
    print("count_param  : ",count_param(model))
    # model = torchvision.models.vgg
    # model.to(device)
    # model = torchvision.models.vgg16()
    #
    # print(summary(model, (3, 224, 224)))

参数量的计算见上一篇博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值