U-Net入门(一)构建一个简单的U-Net网络

构建一个简单的U-Net网络

注:本文仅搭建网络结构,并未实现训练及预测

注:以下网络仅供学习交流,实际在投入训练过程中loss不能正确收敛,请不要直接使用,如有大佬知道什么情况,烦请指点一二


在这里插入图片描述

一、导包

import torch
from torch import nn

二、下采样

'''
    下采样,即U-Net的左半部分
'''
class block_down(nn.Module):
    def __init__(self,inp_channel,out_channel):
        """
        :param inp_channel: 输入通道数
        :param out_channel: 输出通道数
        """
        # 调用父类方法,传block_down是父类名字
        super(block_down,self).__init__()
        # 所有的卷积层添加padding=1会填充1个像素点,实现输入和输出的维度相同,也可以不选
        # 注:本文中所有关于维度的数据均未添加padding,如果需要添加padding,需要自己逐步调试计算维度
        # self.conv1 = nn.Conv2d(inp_channel, out_channel, 3, 1,padding=1)
        # self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1,padding=1)
        # 在这里面定义卷积层、标准化层和激活层,方便在forward方法中调用
        self.conv1=nn.Conv2d(inp_channel,out_channel,3,1)
        self.conv2=nn.Conv2d(out_channel,out_channel,3,1)
        # BatchNorm2d层是批量标准化层,可以加快收敛速度
        self.bn=nn.BatchNorm2d(out_channel)
        # 激活函数
        self.relu=nn.ReLU6(inplace=True)

    '''
        这个里面就是卷积两次,就是U-Net网络左半部分的每一行
        注意每次卷积之后都要标准化和激活
    '''
    def forward(self,x):
        x=self.conv1(x)
        x=self.bn(x)
        x=self.relu(x)
        x=self.conv2(x)
        x=self.bn(x)
        x=self.relu(x)
        return x

三、上采样

'''
    上采样模块,结构的右半部分
'''
class block_up(nn.Module):
    def __init__(self,inp_channel,out_channel,y):
        super(block_up,self).__init__()
        # 使用卷积转置实现上采样,增加输入特征图的尺寸
        self.up=nn.ConvTranspose2d(inp_channel,out_channel,2,2)
        # 所有的卷积层添加padding=1会填充1个像素点,实现输入和输出的维度相同,也可以不选
        self.conv1=nn.Conv2d(inp_channel,out_channel,3,1)
        self.conv2=nn.Conv2d(out_channel,out_channel,3,1)
        self.bn=nn.BatchNorm2d(out_channel)
        self.relu=nn.ReLU6(inplace=True)
        self.y=y

    def forward(self,x):
        x = self.up(x) # 上采样
        '''
            需要对y进行处理,因为传入的y和上采样得到的维度不同,无法直接进行拼接,需要进行裁剪
            首先以第一层上采样为例
            block6=block_up(1024,512,x4_use)
            x6=block6(x5)
            传入的x4_use的size是    ([1, 512, 52, 72]),在这里是y
            要与其拼接的x5的size是   ([1, 1024,22, 32]),在这里是x
            x经过一层上采样后的size是 ([1, 512, 44, 64])
            显然第三个维度不一致,无法进行拼接,y的第三个维度较大,需要进行裁剪
            在这里采用头尾裁切,取中间的方法
            所以有了关于delta的计算
            计算思路为52-44=8,两个数据差8,头尾各去掉4,即可实现拼接
            第四个维度同理
            即self.y=self.y[:,:,delta:self.y.shape[2]-delta,delta:self.y.shape[3]-delta]
        '''
        if self.y.shape[2]!=x.shape[2]:
            delta1=self.y.shape[2]-x.shape[2]
            delta=delta1//2
            '''
                以第二层
                block7=block_up(512,256,x3_use)
                x7=block7(x6)
                这层为例
                进入该方法时x3_use的Size是  ([1, 256, 113, 153]),在这个方法里是y
                要与其拼接的x6的Size是      ([1, 512, 40, 60]),在这个方法里是x
                经过一层上采样后,x的Size变为([1, 256, 80, 120]),显然也无法进行拼接
                经过计算delta1=33,delta=16
                这里存在一个问题,delta是奇数,经过整除运算余1,两边同时去掉delta还差1,维度同样不匹配
                所以需要在一边减去这个多余1,所以有了如下关于delta1奇偶的判断
                如果为奇数,则上限多-1
            '''
            if delta1%2==0:
                self.y=self.y[:,:,delta:self.y.shape[2]-delta,delta:self.y.shape[3]-delta]
            else:
                self.y=self.y[:,:,delta:self.y.shape[2]-delta-1,delta:self.y.shape[3]-delta-1]

        # 将x和y在第二个维度上进行拼接,就是第二个维度的加法操作
        x=torch.cat([x,self.y],dim=1)
        # 正常的每一行卷两下
        x=self.conv1(x)
        x=self.bn(x)
        x=self.relu(x)
        x=self.conv2(x)
        x=self.bn(x)
        x=self.relu(x)
        return x

四、构建网络

class U_net(nn.Module):
    def __init__(self,out_channel):
        super(U_net,self).__init__()
        self.out=nn.Conv2d(64,out_channel,1)
        # 使用最大池化实现下采样
        self.maxpool=nn.MaxPool2d(2)

    '''
        U-Net网络的架构
    '''
    def forward(self,x):
        # 下采样层
        # 一个block就是左半或者右半的一横行,就是卷积两下
        # use的作用在于crop and copy,裁剪和复制
        # use必须留在这里用于后续和上采样的对应层进行拼接,所以没有写进方法
        # 最大池化实现下采样
        block1=block_down(3,64)
        x1_use=block1(x) # torch.Size([1, 64, 476, 636])
        x1=self.maxpool(x1_use) #torch.Size([1, 64, 238, 318])

        block2=block_down(64,128) # torch.Size([1, 128, 119, 158])
        x2_use=block2(x1) # torch.Size([1, 128, 234, 314])
        x2=self.maxpool(x2_use) #torch.Size([1, 256, 117, 157])

        block3=block_down(128,256)
        x3_use=block3(x2) # torch.Size([1, 256, 113, 153])
        x3=self.maxpool(x3_use) # torch.Size([1, 256, 56, 76])

        block4=block_down(256,512)
        x4_use=block4(x3) # torch.Size([1, 512, 52, 72])
        x4=self.maxpool(x4_use) # torch.Size([1, 512, 26, 36])

        # 这层不需要池化了,到底层了
        block5=block_down(512,1024)
        x5=block5(x4) # torch.Size([1, 1024, 22, 32])

        # 上采样层
        # 上采样时我们把转置卷积写在方法里面了,所以没有跨行操作
        block6=block_up(1024,512,x4_use)
        x6=block6(x5)
        block7=block_up(512,256,x3_use)
        x7=block7(x6)
        block8=block_up(256,128,x2_use)
        x8=block8(x7)
        block9=block_up(128,64,x1_use)
        x9=block9(x8)
        x10=self.out(x9)
        out=nn.Softmax2d()(x10)
        return out

五、测试

'''
    main只是在测试输入输出的size是否一致,并没有执行训练,本文也并没有写训练相关代码
'''
if __name__=="__main__":
    # 创建一个测试输入,是一个随机图像
    test_input=torch.rand(1,3,480,640)
    # 输出测试输入的size
    print("test_input:",test_input.size())
    # 创建模型,输出通道为3
    model=U_net(out_channel=3)
    # 进行前向传播
    output=model(test_input)
    # 输出测试size
    print("output size:",output.size())
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值