从零开始写出一个Unet的model.py文件

1.双卷积,池化,双卷积,池化,双卷积,池化,双卷积,池化,双卷积

上采样,拼接,双卷积,上采样,拼接,双卷积,上采样,拼接,双卷积,上采样,拼接,双卷积

最后388*388*64的特征图经过一个1*1卷积(两个卷积核)得到一个388*388*2的特征图(这是我们最后的输出分割图)

也就是说(双卷积+池化)*4+双卷积+(上采样+拼接+双卷积)*4+1*1卷积

其实你说是双卷积+池化)*4+(双卷积+上采样+拼接)*4+双卷积+1*1卷积

2.双卷积函数:传入两个参数,输入的通道数和经过双卷积后的输出通道数3

def double_conv(in_channels, out_channels):  # 双层卷积模型,神经网络最基本的框架
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),  # 3指kernel_size,即卷积核3*3
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True))

#下面演示调用这个函数
double_conv(1, 32),表示输入通道数为1,输出通道数为32
#下面也是同理
double_conv(32, 64)
double_conv(64, 128)
double_conv(128, 256)
double_conv(256, 512)

3.model.py文件里面 

这一次我们要实现的unet比上面的结构多了一层,我们的输入图片是120*140*1,我们加一个卷积层变成120*120*1,再卷积或者padding变成128*128*1(接下来就放进Unet里面),所以我们的网络结构其实是如下图:

 整个Model.py文件,就是一个Class Unet,继承于nn.Module,然后里面有两个函数,一个初始化函数__init__,另一个forward函数(输入特征图到网络得到输出的一个整个过程),所以可以得到框架:

Class Unet(nn.Module):

   def __init__(.............):
       ...........
       ...........
   
   def  forward(self,x):
    
       ...........
       ...........

 这两个函数其实你看forward函数就好了

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):  # 双层卷积模型,神经网络最基本的框架
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),  # 3指kernel_size,即卷积核3*3
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True))
class UNet(nn.Module):

    def __init__(self):
        super().__init__()

        self.adnet = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(2, 1), padding=0, stride=(2, 1)),
            nn.BatchNorm2d(1),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度
            nn.ReLU(inplace=True),
            nn.Conv2d(1, 1, kernel_size=3, padding=5, stride=1),
            nn.BatchNorm2d(1),  # 加入Bn层提高网络泛化能力(防止过拟合),加收敛速度
            nn.ReLU(inplace=True))

        self.dconv_down0 = double_conv(1, 32)
        self.dconv_down1 = double_conv(32, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample4 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1)
        self.upsample3 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)
        self.upsample2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.upsample1 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.upsample0 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1)

        self.dconv_up3 = double_conv(256 + 256, 256)
        self.dconv_up2 = double_conv(128 + 128, 128)
        self.dconv_up1 = double_conv(64 + 64, 64)
        self.dconv_up0 = double_conv(64, 32)

        self.conv_last = nn.Conv2d(16, 1, 1)

    def forward(self, x):
        # reshape
        x = self.adnet(x)  # 120*240*1的图片输入到这个函数里,变成128*128*1

        # encode
        conv0 = self.dconv_down0(x)  # 32x128x128
        x = self.maxpool(conv0)  # 32x64x64

        conv1 = self.dconv_down1(x)  # 64x64x64
        x = self.maxpool(conv1)  # 64x32x32

        conv2 = self.dconv_down2(x)  # 128x32x32
        x = self.maxpool(conv2)  # 128x16x16

        conv3 = self.dconv_down3(x)  # 256x16x16
        x = self.maxpool(conv3)  # 256x8x8

        x = self.dconv_down4(x)  # 512x8x8

        # decode
        x = self.upsample4(x)  # 256x16x16
        # 因为使用了3*3卷积核和 padding=1 的组合,所以卷积过程图像尺寸不发生改变,所以省去了crop操作!
        x = torch.cat([x, conv3], dim=1)  # 512x16x16

        x = self.dconv_up3(x)  # 256x16x16
        x = self.upsample3(x)  # 128x32x32
        x = torch.cat([x, conv2], dim=1)  # 256x32x32

        x = self.dconv_up2(x)  # 128x32x32
        x = self.upsample2(x)  # 64x64x64
        x = torch.cat([x, conv1], dim=1)  # 128x64x64

        x = self.dconv_up1(x)  # 64x64x64
        x = self.upsample1(x)  # 32x128x128
        x = torch.cat([x, conv0], dim=1)  # 64x128x128

        x = self.dconv_up0(x)  # 32x128x128
        x = self.upsample0(x)   # 16x256x256

        out = self.conv_last(x)  # 1x256x256

        return out

下面的代码也是行得通的:

import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # left
        self.left_conv_1 = DoubleConv(3, 64)
        self.down_1 = nn.MaxPool2d(2, 2)

        self.left_conv_2 = DoubleConv(64, 128)
        self.down_2 = nn.MaxPool2d(2, 2)

        self.left_conv_3 = DoubleConv(128, 256)
        self.down_3 = nn.MaxPool2d(2, 2)

        self.left_conv_4 = DoubleConv(256, 512)
        self.down_4 = nn.MaxPool2d(2, 2)

        # center
        self.center_conv = DoubleConv(512, 1024)

        # right
        self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.right_conv_1 = DoubleConv(1024, 512)

        self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.right_conv_2 = DoubleConv(512, 256)

        self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.right_conv_3 = DoubleConv(256, 128)

        self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.right_conv_4 = DoubleConv(128, 64)

        # output
        self.output = nn.Conv2d(64, 3, 1, 1, 0)

    def forward(self, x):
        # left
        x1 = self.left_conv_1(x)
        x1_down = self.down_1(x1)

        x2 = self.left_conv_2(x1_down)
        x2_down = self.down_2(x2)

        x3 = self.left_conv_3(x2_down)
        x3_down = self.down_3(x3)

        x4 = self.left_conv_4(x3_down)
        x4_down = self.down_4(x4)

        # center
        x5 = self.center_conv(x4_down)

        # right
        x6_up = self.up_1(x5)
        temp = torch.cat((x6_up, x4), dim=1)
        x6 = self.right_conv_1(temp)

        x7_up = self.up_2(x6)
        temp = torch.cat((x7_up, x3), dim=1)
        x7 = self.right_conv_2(temp)

        x8_up = self.up_3(x7)
        temp = torch.cat((x8_up, x2), dim=1)
        x8 = self.right_conv_3(temp)

        x9_up = self.up_4(x8)
        temp = torch.cat((x9_up, x1), dim=1)
        x9 = self.right_conv_4(temp)

        # output
        output = self.output(x9)

        return output

if __name__ == "__main__":
    a = torch.rand(10, 3, 32, 32)
    model = UNet()
    b = model(a)
    print(b.size())  # 输出得到:torch.Size([10, 3, 32, 32])

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值