Building a U-Net in PyTorch

Preface

This post studies U-Net, a classic semantic segmentation network, and builds it in PyTorch.

How It Works

[Figure: U-Net architecture diagram]
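The U-Net architecture is as follows.
U-Net is a U-shaped network: the left half of the U is the feature-extraction (encoder) network, and the right half is the decoder network. Between the two halves, feature maps at matching resolutions are concatenated along the channel dimension, which fuses information from different levels of the feature hierarchy.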
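The encoder is a classic VGG-style network. Each vgg_block stacks two or three 3×3 convolution + ReLU layers, and there are five vgg_blocks in total. Consecutive blocks are separated by a max-pooling layer that downsamples by halving the width and height.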
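To make the shapes concrete, here is a minimal sketch of one encoder stage. The helper name `vgg_block`, the channel counts, and the 64×64 input are illustrative assumptions; `padding=1` is also an assumption, chosen so that only the pooling layer changes the spatial size.

```python
import torch
import torch.nn as nn


def vgg_block(num_convs, in_channels, out_channels):
    """Stack num_convs 3x3 conv + ReLU pairs (hypothetical helper for illustration)."""
    layers = []
    for i in range(num_convs):
        # Only the first convolution changes the channel count
        layers.append(nn.Conv2d(in_channels if i == 0 else out_channels,
                                out_channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


block = vgg_block(2, 1, 64)
pool = nn.MaxPool2d(kernel_size=2, stride=2)

x = torch.randn(1, 1, 64, 64)   # (batch, channels, height, width)
features = block(x)             # channels 1 -> 64, spatial size unchanged
pooled = pool(features)         # width and height halved
print(features.shape, pooled.shape)
```

The pooled tensor is what the next block would receive, so each stage works at half the resolution of the previous one.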
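Each decoder stage first upsamples with a transposed convolution (kernel_size 2×2, stride 2), which halves the number of channels and doubles the width and height. The result is concatenated along the channel dimension with the corresponding encoder feature map and then passed through two 3×3 convolution + ReLU layers. There are four upsampling stages in total.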
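The upsample-then-concatenate step can be checked in isolation. This is a small sketch with made-up tensor sizes (128 channels at 16×16 for the deep map, 64 channels at 32×32 for the encoder map); only the shapes matter here.

```python
import torch
import torch.nn as nn

deep = torch.randn(1, 128, 16, 16)  # feature map coming from the deeper stage
skip = torch.randn(1, 64, 32, 32)   # encoder feature map at the target resolution

# kernel_size=2 with stride=2 doubles the width and height;
# the output channel count is chosen to be half of the input
up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
upsampled = up(deep)

# Concatenate along dim=1, the channel dimension
merged = torch.cat([skip, upsampled], dim=1)
print(upsampled.shape, merged.shape)
```

After concatenation the channel count is back to 128, which is why the first 3×3 convolution of a decoder stage takes as many input channels as the transposed convolution did.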
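Finally, the output layer uses a 1×1 convolution to map the channels to the required number of classes, and a softmax activation turns the result into a per-pixel probability map.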
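As a rough illustration of the output layer, the sketch below applies a 1×1 convolution and a softmax to a random stand-in feature map. The names `head` and `num_classes`, and the choice of three classes, are assumptions for the example.

```python
import torch
import torch.nn as nn

num_classes = 3                              # assumed number of classes
head = nn.Conv2d(64, num_classes, kernel_size=1)

features = torch.randn(2, 64, 32, 32)        # stand-in for the last decoder output
logits = head(features)                      # one score per class per pixel
probs = torch.softmax(logits, dim=1)         # normalize over the class dimension
pred = probs.argmax(dim=1)                   # most likely class index per pixel
print(logits.shape, pred.shape)
```

The softmax is taken over `dim=1`, so the class probabilities at every pixel sum to 1. Losses such as `nn.CrossEntropyLoss` expect the raw scores rather than the probabilities, so in training code the softmax is often left out of the model.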

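U-Net has two main benefits. (1) Shallow convolutional layers capture texture, while deep layers capture more abstract semantic features; the concatenation lets U-Net use both. (2) Downsampling in the encoder discards some edge information that upsampling alone cannot learn back; concatenating the encoder features recovers it, which makes predictions near edges more precise.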

Implementation

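First comes the downsampling block, which resembles a vgg_block. Because the image and its label usually have the same size, every convolution here uses padding=1 so that the spatial size is preserved; this differs slightly from the figure.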
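The effect of the padding choice is easy to verify. In this sketch the 572×572 input and the 8 output channels are arbitrary illustrative values; the comparison is between a 3×3 convolution with `padding=1` and one with no padding.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 572, 572)

conv_same = nn.Conv2d(1, 8, kernel_size=3, padding=1)   # keeps width and height
conv_valid = nn.Conv2d(1, 8, kernel_size=3)             # padding=0: loses 2 pixels per axis

with torch.no_grad():
    same_out = conv_same(x)
    valid_out = conv_valid(x)

print(same_out.shape)    # spatial size stays 572 x 572
print(valid_out.shape)   # spatial size shrinks to 570 x 570
```

Without padding, every convolution trims the feature map, so the output mask ends up smaller than the input and the encoder maps have to be cropped before concatenation. With padding=1 the prediction lines up pixel for pixel with the label.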

import torch
import torch.nn as nn


class DownBlock(nn.Module):
    """Optional 2x2 max pooling followed by num_convs 3x3 conv + ReLU layers."""

    def __init__(self, num_convs, inchannels, outchannels, pool=True):
        super(DownBlock, self).__init__()
        blk = []
        if pool:
            # Halve the width and height before the convolutions
            blk.append(nn.MaxPool2d(kernel_size=2, stride=2))
        for i in range(num_convs):
            # Only the first convolution changes the channel count
            in_ch = inchannels if i == 0 else outchannels
            blk.append(nn.Conv2d(in_ch, outchannels, kernel_size=3, padding=1))
            blk.append(nn.ReLU(inplace=True))
        self.layer = nn.Sequential(*blk)

    def forward(self, x):
        return self.layer(x)

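Next comes the upsampling block. It first upsamples with a transposed convolution (kernel_size 2×2, stride 2), then concatenates the result with the corresponding encoder feature map, and finally applies two 3×3 convolution + ReLU layers.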

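class UpBlock(nn.Module):
    """Transposed-conv upsampling, skip concatenation, then two 3x3 conv + ReLU layers.

    Assumes inchannels == 2 * outchannels, as in the UNet class.
    """

    def __init__(self, inchannels, outchannels):
        super(UpBlock, self).__init__()
        # Doubles the width and height and reduces the channels to outchannels
        self.convt = nn.ConvTranspose2d(inchannels, outchannels, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            # Input here is the concatenation: outchannels (skip) + outchannels (upsampled)
            nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        # x1: deeper feature map to upsample; x2: encoder feature map to concatenate
        x1 = self.convt(x1)
        x = torch.cat([x2, x1], dim=1)  # concatenate along the channel dimension
        x = self.conv(x)
        return x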

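The complete U-Net has five encoder blocks and four upsampling stages. The first encoder block has no pooling, so the input is downsampled four times.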

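class UNet(nn.Module):
    """U-Net with five encoder blocks and four decoder blocks.

    The input width and height should be divisible by 16 (four 2x2 poolings).
    """

    def __init__(self, nchannels=1, nclasses=1):
        super(UNet, self).__init__()
        # Encoder: the first block keeps the input resolution (no pooling)
        self.down1 = DownBlock(2, nchannels, 64, pool=False)
        self.down2 = DownBlock(3, 64, 128)
        self.down3 = DownBlock(3, 128, 256)
        self.down4 = DownBlock(3, 256, 512)
        self.down5 = DownBlock(3, 512, 1024)
        # Decoder: each block halves the channels and doubles the resolution
        self.up1 = UpBlock(1024, 512)
        self.up2 = UpBlock(512, 256)
        self.up3 = UpBlock(256, 128)
        self.up4 = UpBlock(128, 64)
        # 1x1 convolution maps the 64 channels to one score per class
        self.out = nn.Sequential(
            nn.Conv2d(64, nclasses, kernel_size=1)
        )

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        # Each decoder stage receives the encoder map of the same resolution
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        # Returns raw scores (logits); apply softmax (or sigmoid when nclasses=1) for probabilities
        return self.out(x)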
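One practical detail of this design: the network pools four times, so the concatenations only line up when the input width and height are multiples of 16. The sketch below shows one possible way to handle other sizes by zero-padding the input and cropping the output afterwards. The helper `pad_to_multiple` is hypothetical and not part of the original code.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, multiple=16):
    """Zero-pad the bottom and right edges so H and W become multiples of `multiple`.

    Hypothetical helper; returns the padded tensor and the original (H, W).
    """
    h, w = x.shape[-2:]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    # For a 4D tensor, F.pad takes (left, right, top, bottom) for the last two dims
    padded = F.pad(x, (0, pad_w, 0, pad_h))
    return padded, (h, w)


x = torch.randn(1, 1, 100, 75)
padded, (h, w) = pad_to_multiple(x)
print(padded.shape)              # height 100 -> 112, width 75 -> 80

# A model output computed on `padded` would be cropped back the same way:
cropped = padded[..., :h, :w]
print(cropped.shape)
```

With a model instance at hand, the call would look like `model(padded)[..., :h, :w]`. Resizing the input to a suitable size is another common option; which one is preferable depends on the data.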
