U-Net解析:含代码实现

U-Net 综述

  U-Net 是一种用于图像分割的卷积神经网络架构,其设计旨在处理生物医学图像分割任务。U-Net 的网络结构具有对称性,包含编码器和解码器两个主要部分,并通过跳跃连接(skip connections)将两者连接起来。
  U-Net 网络结构因其对称性而得名,形似英文字母 “U”。整个网络架构由蓝色和白色框表示特征图(feature map),不同颜色的箭头则代表了不同的操作和连接方式。具体而言:

  • 蓝色箭头表示 3x3 卷积操作,用于特征提取,旨在捕捉输入数据中的重要特征。
  • 灰色箭头表示跳跃连接(skip connection),用于特征融合,确保在解码阶段能够有效地利用编码阶段提取的高分辨率特征。
  • 红色箭头表示池化操作(pooling),用于降低特征图的空间维度,从而减少计算量并提取更具抽象性的特征。
  • 绿色箭头表示上采样(upsample)操作,用于恢复特征图的空间维度,以便与编码器的特征图进行拼接。
  • 青色箭头表示 1x1 卷积操作,用于生成最终的输出结果。
      在跳跃连接中,“copy and crop” 的过程中的 “copy” 实际上是指特征图的拼接(concatenate),而 “crop” 则是为了确保拼接的特征图在长宽上保持一致。
      关于网络层数的选择,U-Net 采用了 5 层的结构,而非 4 层或 6 层。这一设计选择可能与作者在特定数据集上的实验结果有关,表明该层数在当时的任务中表现最佳。然而,这并不意味着该结构适用于所有数据集。我们应当关注的是这种编码器-解码器(Encoder-Decoder)的设计思想,而具体的实现细节应根据不同数据集的特性进行调整。
      在编码器部分,网络由卷积操作和下采样操作构成。文中所采用的卷积结构统一为 3x3 的卷积核,且未使用填充(padding),步幅(striding)设置为 1。由于没有填充,特征图的高度(H)和宽度(W)在每次卷积后都会减小,因此在进行跳跃连接时需要特别注意特征图的维度匹配。为了避免维度不一致的问题,实际上可以选择在卷积操作中使用填充(padding)为 1 的设置。
    在这里插入图片描述

U-Net 网络结构

1. 编码器(下采样路径)

  • 卷积层
    • 编码器由多个卷积块组成,每个卷积块通常包含两个卷积层。每个卷积层后面跟随一个激活函数(如 ReLU)和批归一化(Batch Normalization)。
    • 每个卷积层的卷积核大小通常为 3x3,填充(padding)为 1,以保持特征图的空间维度。
  • 池化层
    • 在每个卷积块之后,使用最大池化(Max Pooling)层进行下采样,通常采用 2x2 的池化窗口,步幅为 2。这将特征图的空间维度减半,同时增加特征图的通道数。
  • 特征提取
    • 随着网络的深入,特征图的通道数逐渐增加,通常是 64、128、256、512 等。

2. 解码器(上采样路径)

  • 上采样层
    • 解码器通过上采样层逐步恢复特征图的空间分辨率。上采样可以通过转置卷积(Transpose Convolution)或双线性插值(Bilinear Interpolation)实现。
    • 每次上采样后,特征图的空间维度翻倍。
  • 跳跃连接
    • 在每个上采样步骤中,将对应的编码器层的特征图与解码器的特征图进行拼接(concatenation)。这种跳跃连接允许模型在上采样时保留低级特征,从而帮助恢复细节信息。
  • 卷积层
    • 在解码器的每个上采样步骤后,通常会有一个卷积层来进一步处理拼接后的特征图,以减少通道数并提取特征。

3. 输出层

  • 1x1 卷积
    • 最后,解码器的输出通过一个 1x1 卷积层进行处理,以生成与输入图像相同尺寸的输出特征图。输出特征图的通道数通常与分割任务的类别数相同。

4. U-Net 的整体结构示意图

Input Image
    |
    v
[Conv2D + ReLU] -> [Conv2D + ReLU] -> [Max Pooling]
    |
    v
[Conv2D + ReLU] -> [Conv2D + ReLU] -> [Max Pooling]
    |
    v
[Conv2D + ReLU] -> [Conv2D + ReLU] -> [Max Pooling]
    |
    v
[Conv2D + ReLU] -> [Conv2D + ReLU] -> [Max Pooling]
    |
    v
[Conv2D + ReLU] -> [Conv2D + ReLU]
    |
    v
[UpSampling] + [Skip Connection] -> [Conv2D + ReLU] -> [Conv2D + ReLU]
    |
    v
[UpSampling] + [Skip Connection] -> [Conv2D + ReLU] -> [Conv2D + ReLU]
    |
    v
[UpSampling] + [Skip Connection] -> [Conv2D + ReLU] -> [Conv2D + ReLU]
    |
    v
[UpSampling] + [Skip Connection] -> [Conv2D + ReLU] -> [Conv2D + ReLU]
    |
    v
[1x1 Conv] -> Output Segmentation Map

5. 代码 实现

下面是 U-Net 的实现代码,包含了编码器、解码器和跳跃连接的结构。

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_prob=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(negative_slope=0.01)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout(x)
        return x


class UpConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.upconv(x)
        x = self.relu(x)
        return x


class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 编码器部分
        self.enc1 = ConvBlock(in_channels, 64)
        self.enc2 = ConvBlock(64, 128)
        self.enc3 = ConvBlock(128, 256)
        self.enc4 = ConvBlock(256, 512)
        # 中间层
        self.bottleneck = ConvBlock(512, 1024)
        # 解码器部分
        self.dec4 = UpConvBlock(1024, 512)
        self.dec4_conv = ConvBlock(1024, 512)
        self.dec3 = UpConvBlock(512, 256)
        self.dec3_conv = ConvBlock(512, 256)
        self.dec2 = UpConvBlock(256, 128)
        self.dec2_conv = ConvBlock(256, 128)
        self.dec1 = UpConvBlock(128, 64)
        self.dec1_conv = ConvBlock(128, 64)
        # 输出层
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        print(f"Input shape: {x.shape}")
        # 编码器部分
        enc1 = self.enc1(x)
        print(f"Output shape after enc1: {enc1.shape}")
        enc1_pooled = F.max_pool2d(enc1, kernel_size=2)
        print(f"Output shape after max_pool2d (enc1): {enc1_pooled.shape}")

        enc2 = self.enc2(enc1_pooled)
        print(f"Output shape after enc2: {enc2.shape}")
        enc2_pooled = F.max_pool2d(enc2, kernel_size=2)
        print(f"Output shape after max_pool2d (enc2): {enc2_pooled.shape}")

        enc3 = self.enc3(enc2_pooled)
        print(f"Output shape after enc3: {enc3.shape}")
        enc3_pooled = F.max_pool2d(enc3, kernel_size=2)
        print(f"Output shape after max_pool2d (enc3): {enc3_pooled.shape}")

        enc4 = self.enc4(enc3_pooled)
        print(f"Output shape after enc4: {enc4.shape}")
        enc4_pooled = F.max_pool2d(enc4, kernel_size=2)
        print(f"Output shape after max_pool2d (enc4): {enc4_pooled.shape}")

        # 中间层
        bottleneck = self.bottleneck(enc4_pooled)
        print(f"Output shape after bottleneck: {bottleneck.shape}")

        # 解码器部分
        dec4 = self.dec4(bottleneck)
        print(f"Output shape after dec4 (before resizing): {dec4.shape}")
        dec4_resized = F.interpolate(dec4, size=enc4.shape[2:], mode='bilinear', align_corners=True)
        print(f"Output shape after resizing dec4: {dec4_resized.shape}")
        dec4 = torch.cat((dec4_resized, enc4), dim=1)
        print(f"Output shape after concat dec4: {dec4.shape}")
        dec4 = self.dec4_conv(dec4)
        print(f"Output shape after dec4_conv: {dec4.shape}")

        dec3 = self.dec3(dec4)
        print(f"Output shape after dec3 (before resizing): {dec3.shape}")
        dec3_resized = F.interpolate(dec3, size=enc3.shape[2:], mode='bilinear', align_corners=True)
        print(f"Output shape after resizing dec3: {dec3_resized.shape}")
        dec3 = torch.cat((dec3_resized, enc3), dim=1)
        print(f"Output shape after concat dec3: {dec3.shape}")
        dec3 = self.dec3_conv(dec3)
        print(f"Output shape after dec3_conv: {dec3.shape}")

        dec2 = self.dec2(dec3)
        print(f"Output shape after dec2 (before resizing): {dec2.shape}")
        dec2_resized = F.interpolate(dec2, size=enc2.shape[2:], mode='bilinear', align_corners=True)
        print(f"Output shape after resizing dec2: {dec2_resized.shape}")
        dec2 = torch.cat((dec2_resized, enc2), dim=1)
        print(f"Output shape after concat dec2: {dec2.shape}")
        dec2 = self.dec2_conv(dec2)
        print(f"Output shape after dec2_conv: {dec2.shape}")

        dec1 = self.dec1(dec2)
        print(f"Output shape after dec1 (before resizing): {dec1.shape}")
        dec1_resized = F.interpolate(dec1, size=enc1.shape[2:], mode='bilinear', align_corners=True)
        print(f"Output shape after resizing dec1: {dec1_resized.shape}")
        dec1 = torch.cat((dec1_resized, enc1), dim=1)
        print(f"Output shape after concat dec1: {dec1.shape}")
        dec1 = self.dec1_conv(dec1)
        print(f"Output shape after dec1_conv: {dec1.shape}")

        # 最后输出层
        return self.final_conv(dec1)


# 示例用法
if __name__ == '__main__':
    batch_size = 4
    model = UNet(in_channels=1, out_channels=2)  # 输入通道为1,输出通道为2(如二分类)
    x = torch.randn(batch_size, 1, 572, 572)  # 示例输入
    preds = model(x)
    print(f"Final output shape: {preds.shape}")  # 输出形状

程序运行结果

Input shape: torch.Size([4, 1, 572, 572])
Output shape after enc1: torch.Size([4, 64, 572, 572])
Output shape after max_pool2d (enc1): torch.Size([4, 64, 286, 286])
Output shape after enc2: torch.Size([4, 128, 286, 286])
Output shape after max_pool2d (enc2): torch.Size([4, 128, 143, 143])
Output shape after enc3: torch.Size([4, 256, 143, 143])
Output shape after max_pool2d (enc3): torch.Size([4, 256, 71, 71])
Output shape after enc4: torch.Size([4, 512, 71, 71])
Output shape after max_pool2d (enc4): torch.Size([4, 512, 35, 35])
Output shape after bottleneck: torch.Size([4, 1024, 35, 35])
Output shape after dec4 (before resizing): torch.Size([4, 512, 70, 70])
Output shape after resizing dec4: torch.Size([4, 512, 71, 71])
Output shape after concat dec4: torch.Size([4, 1024, 71, 71])
Output shape after dec4_conv: torch.Size([4, 512, 71, 71])
Output shape after dec3 (before resizing): torch.Size([4, 256, 142, 142])
Output shape after resizing dec3: torch.Size([4, 256, 143, 143])
Output shape after concat dec3: torch.Size([4, 512, 143, 143])
Output shape after dec3_conv: torch.Size([4, 256, 143, 143])
Output shape after dec2 (before resizing): torch.Size([4, 128, 286, 286])
Output shape after resizing dec2: torch.Size([4, 128, 286, 286])
Output shape after concat dec2: torch.Size([4, 256, 286, 286])
Output shape after dec2_conv: torch.Size([4, 128, 286, 286])
Output shape after dec1 (before resizing): torch.Size([4, 64, 572, 572])
Output shape after resizing dec1: torch.Size([4, 64, 572, 572])
Output shape after concat dec1: torch.Size([4, 128, 572, 572])
Output shape after dec1_conv: torch.Size([4, 64, 572, 572])
Final output shape: torch.Size([4, 2, 572, 572])

注意事项

  1. 卷积层的定义
    • 在解码器部分,确保每个阶段的卷积层都被正确定义,以便在跳跃连接后进行两个卷积操作。
  2. 跳跃连接
    • 在进行跳跃连接之前,使用 F.interpolate 调整解码器输出的空间维度,以确保与编码器的特征图匹配。
  3. 激活函数和正则化
    • 使用 LeakyReLUBatchNorm2d 可以提高模型的稳定性和性能。根据需要调整 Dropout 的概率,以防止过拟合。

在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值