原始U-Net模型代码

论文:1505.U-Net: Convolutional Networks for Biomedical Image Segmentation
代码: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py

一、原始u-net 架构结构 (输入572x572x1,5层,向下采样4次):

每个蓝色框对应一个多通道特征图。通道的数量在框的顶部表示。x-y 大小在框的左下角提供。白盒代表复制的特征图。箭头表示不同的操作

在这里插入图片描述

1.0 框架代码

import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        """
        初始化函数,定义UNet模型的结构。

        参数:
        n_channels -- 输入图像的通道数,例如RGB图像的通道数为3。
        n_classes -- 输出的类别数,即分割后的图像通道数。
        bilinear -- 是否使用双线性插值进行上采样。默认值为False。
        """
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # 初始卷积层,输入通道数为n_channels,输出通道数为64
        self.inc = DoubleConv(n_channels, 64)
        # 第一个下采样层,输入通道数为64,输出通道数为128
        self.down1 = Down(64, 128)
        # 第二个下采样层,输入通道数为128,输出通道数为256
        self.down2 = Down(128, 256)
        # 第三个下采样层,输入通道数为256,输出通道数为512
        self.down3 = Down(256, 512)
        # 根据是否使用双线性插值设置下采样因子
        factor = 2 if bilinear else 1
        # 第四个下采样层,输入通道数为512,输出通道数为1024(如果使用双线性插值,则输出通道数减半)
        self.down4 = Down(512, 1024 // factor)
        # 第一个上采样层,输入通道数为1024(或512),输出通道数为512(或256)
        self.up1 = Up(1024, 512 // factor, bilinear)
        # 第二个上采样层,输入通道数为512(或256),输出通道数为256(或128)
        self.up2 = Up(512, 256 // factor, bilinear)
        # 第三个上采样层,输入通道数为256(或128),输出通道数为128(或64)
        self.up3 = Up(256, 128 // factor, bilinear)
        # 第四个上采样层,输入通道数为128,输出通道数为64
        self.up4 = Up(128, 64, bilinear)
        # 输出卷积层,输入通道数为64,输出通道数为n_classes
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """
        前向传播函数,定义输入x如何通过各层传递并输出结果。

        参数:
        x -- 输入的图像张量

        返回:
        logits -- 输出的类别概率张量
        """
        x1 = self.inc(x)      # 初始卷积层
        x2 = self.down1(x1)   # 第一个下采样层
        x3 = self.down2(x2)   # 第二个下采样层
        x4 = self.down3(x3)   # 第三个下采样层
        x5 = self.down4(x4)   # 第四个下采样层
        x = self.up1(x5, x4)  # 第一个上采样层,并与对应的下采样层输出拼接
        x = self.up2(x, x3)   # 第二个上采样层,并与对应的下采样层输出拼接
        x = self.up3(x, x2)   # 第三个上采样层,并与对应的下采样层输出拼接
        x = self.up4(x, x1)   # 第四个上采样层,并与对应的下采样层输出拼接
        logits = self.outc(x) # 输出卷积层
        return logits         # 返回最终的类别概率张量

1.1 DoubleConv:2个3x3卷积+relu层的实现

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(卷积 => [批归一化] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        """
        初始化函数,定义双卷积层结构

        参数:
        in_channels -- 输入通道数
        out_channels -- 输出通道数
        mid_channels -- 中间层通道数,如果未指定,则等于输出通道数
        """
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # 定义双卷积层,包含两次卷积,每次卷积后跟批归一化和ReLU激活函数
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """
        前向传播函数,输入x,输出双卷积的结果
        """
        return self.double_conv(x)

1.2 Down:向下采样+同层2次卷积 (特征图维度变为原来的一半)

在这里插入图片描述
每次向下采样后,经过2次卷积层
在这里插入图片描述
对应在架构图为
在这里插入图片描述

class Down(nn.Module):
    """通过最大池化下采样,然后进行双卷积"""

    def __init__(self, in_channels, out_channels):
        """
        初始化函数,定义下采样结构

        参数:
        in_channels -- 输入通道数
        out_channels -- 输出通道数
        """
        super().__init__()
        # 定义下采样层,包含一个2x2最大池化层,然后是一个双卷积层
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        """
        前向传播函数,输入x,输出下采样的结果
        """
        return self.maxpool_conv(x)

1.3 Up: 向上采样+拼接(特征图维度加倍+拼接之前维度)

在这里插入图片描述
每次向上采样后,要接收对应向下采样结果的特征图拼接(padding和crop是类似的方法)
拼接过程中,对称层效果最好
在这里插入图片描述
拼接完成后,经过2次卷积处理)
在这里插入图片描述
代码中up操作
在这里插入图片描述

class Up(nn.Module):
    """通过上采样,然后进行双卷积"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        """
        初始化函数,定义上采样结构

        参数:
        in_channels -- 输入通道数
        out_channels -- 输出通道数
        bilinear -- 是否使用双线性插值进行上采样,默认值为True
        """
        super().__init__()

        # 如果使用双线性插值,则使用常规卷积来减少通道数
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        前向传播函数,输入x1和x2,进行上采样和拼接

        参数:
        x1 -- 来自上一级的输入特征图
        x2 -- 来自对称层的特征图,用于拼接
        """
        x1 = self.up(x1)
        # 输入是CHW格式(通道数,高度,宽度)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # 如果有填充问题,请参考以下链接
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

二、完整代码

# unet for car segement
# https://github.com/milesial/Pytorch-UNet

""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """(卷积 => [批归一化] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        """
        初始化函数,定义双卷积层结构

        参数:
        in_channels -- 输入通道数
        out_channels -- 输出通道数
        mid_channels -- 中间层通道数,如果未指定,则等于输出通道数
        """
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # 定义双卷积层,包含两次卷积,每次卷积后跟批归一化和ReLU激活函数
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """
        前向传播函数,输入x,输出双卷积的结果
        """
        return self.double_conv(x)


class Down(nn.Module):
    """通过最大池化下采样,然后进行双卷积"""

    def __init__(self, in_channels, out_channels):
        """
        初始化函数,定义下采样结构

        参数:
        in_channels -- 输入通道数
        out_channels -- 输出通道数
        """
        super().__init__()
        # 定义下采样层,包含一个2x2最大池化层,然后是一个双卷积层
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        """
        前向传播函数,输入x,输出下采样的结果
        """
        return self.maxpool_conv(x)


class Up(nn.Module):
    """通过上采样,然后进行双卷积"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        """
        初始化函数,定义上采样结构

        参数:
        in_channels -- 输入通道数
        out_channels -- 输出通道数
        bilinear -- 是否使用双线性插值进行上采样,默认值为True
        """
        super().__init__()

        # 如果使用双线性插值,则使用常规卷积来减少通道数
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        前向传播函数,输入x1和x2,进行上采样和拼接

        参数:
        x1 -- 来自上一级的输入特征图
        x2 -- 来自对称层的特征图,用于拼接
        """
        x1 = self.up(x1)
        # 输入是CHW格式(通道数,高度,宽度)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # 如果有填充问题,请参考以下链接
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        """
        初始化函数,定义输出卷积层

        参数:
        in_channels -- 输入通道数
        out_channels -- 输出通道数
        """
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """
        前向传播函数,输入x,输出卷积的结果
        """
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)


if __name__ == '__main__':
    input_image = torch.randn(1, 3, 572, 572)  # 假设输入图像大小为572x572,通道数为3
    unet = UNet(n_channels=3,n_classes=1)

    y=unet(input_image)
    print('y',y.shape)
    print(unet)  # 输出张量的形状

        # 创建一个BatchNorm2d层
    batch_norm = nn.BatchNorm2d(num_features=64)

    # 打印BatchNorm2d层的可学习参数
    print("Gamma (scale parameter):", batch_norm.weight.shape)
    print("Beta (shift parameter):", batch_norm.bias.shape)

    # 创建一个假数据,形状为 (batch_size, channels, height, width)
    input_tensor = torch.randn(8, 64, 32, 32)

    # 应用批量归一化
    output_tensor = batch_norm(input_tensor)

    # 打印输出形状
    print(output_tensor.shape)

附录

nn.BatchNorm2d 层相关

是 PyTorch 中用于批量归一化(Batch Normalization)的模块之一,专门用于处理 2D 图像数据。批量归一化是一种在训练深度神经网络时加速和稳定训练过程的技术。它通过标准化每个小批量的输入来减少内部协变量偏移。

具体公式

y n c h w = γ c x n c h w − μ c σ c 2 + ϵ + β c y_{nchw} = \gamma_c \frac{x_{nchw} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c ynchw=γcσc2+ϵ xnchwμc+βc

参数解释

x n c h w x_{nchw} xnchw 输入特征图的向量
y n c h w y_{nchw} ynchw:批量归一化后的输出特征图的值
γ c \gamma_c γc: 可学习缩放参数(scale parameter),用于调整标准化后的特征图的分布。
β c \beta_c βc: 可学习平移参数(shift parameter),用于调整标准化后的特征图的偏移。
ϵ \epsilon ϵ:一个很小的常数,用于防止除零操作,通常取 1e−5)

计算batch=N,HxW特征图,通道C均值 μ c \mu_c μc

μ c = 1 N × H × W ∑ n = 1 N ∑ h = 1 H ∑ w = 1 W x n c h w \mu_c = \frac{1}{N \times H \times W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{nchw} μc=N×H×W1n=1Nh=1Hw=1Wxnchw
计算均值 σ c 2 \sigma_c^2 σc2
σ c 2 = 1 N × H × W ∑ n = 1 N ∑ h = 1 H ∑ w = 1 W ( x n c h w − μ c ) 2 \sigma_c^2 = \frac{1}{N \times H \times W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{nchw} - \mu_c)^2 σc2=N×H×W1n=1Nh=1Hw=1W(xnchwμc)2

标准化 x ^ n c h w \hat{x}_{nchw} x^nchw,

x ^ n c h w = x n c h w − μ c σ c 2 + ϵ \hat{x}_{nchw} = \frac{x_{nchw} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} x^nchw=σc2+ϵ xnchwμc

缩放和平移:

y n c h w = γ c x ^ n c h w + β c y_{nchw} = \gamma_c \hat{x}_{nchw} + \beta_c ynchw=γcx^nchw+βc

  • 18
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

曾小蛙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值