U-Net 综述
U-Net 是一种用于图像分割的卷积神经网络架构,其设计旨在处理生物医学图像分割任务。U-Net 的网络结构具有对称性,包含编码器和解码器两个主要部分,并通过跳跃连接(skip connections)将两者连接起来。
U-Net 网络结构因其对称性而得名,形似英文字母 “U”。整个网络架构由蓝色和白色框表示特征图(feature map),不同颜色的箭头则代表了不同的操作和连接方式。具体而言:
- 蓝色箭头表示 3x3 卷积操作,用于特征提取,旨在捕捉输入数据中的重要特征。
- 灰色箭头表示跳跃连接(skip connection),用于特征融合,确保在解码阶段能够有效地利用编码阶段提取的高分辨率特征。
- 红色箭头表示池化操作(pooling),用于降低特征图的空间维度,从而减少计算量并提取更具抽象性的特征。
- 绿色箭头表示上采样(upsample)操作,用于恢复特征图的空间维度,以便与编码器的特征图进行拼接。
- 青色箭头表示 1x1 卷积操作,用于生成最终的输出结果。
在跳跃连接中,“copy and crop” 的过程中的 “copy” 实际上是指特征图的拼接(concatenate),而 “crop” 则是为了确保拼接的特征图在长宽上保持一致。
关于网络层数的选择,U-Net 采用了 5 层的结构,而非 4 层或 6 层。这一设计选择可能与作者在特定数据集上的实验结果有关,表明该层数在当时的任务中表现最佳。然而,这并不意味着该结构适用于所有数据集。我们应当关注的是这种编码器-解码器(Encoder-Decoder)的设计思想,而具体的实现细节应根据不同数据集的特性进行调整。
在编码器部分,网络由卷积操作和下采样操作构成。文中所采用的卷积结构统一为 3x3 的卷积核,且未使用填充(padding),步幅(striding)设置为 1。由于没有填充,特征图的高度(H)和宽度(W)在每次卷积后都会减小,因此在进行跳跃连接时需要特别注意特征图的维度匹配。为了避免维度不一致的问题,实际上可以选择在卷积操作中使用填充(padding)为 1 的设置。
U-Net 网络结构
1. 编码器(下采样路径)
- 卷积层:
- 编码器由多个卷积块组成,每个卷积块通常包含两个卷积层。每个卷积层后面跟随一个激活函数(如 ReLU)和批归一化(Batch Normalization)。
- 每个卷积层的卷积核大小通常为 3x3,填充(padding)为 1,以保持特征图的空间维度。
- 池化层:
- 在每个卷积块之后,使用最大池化(Max Pooling)层进行下采样,通常采用 2x2 的池化窗口,步幅为 2。这将特征图的空间维度减半,同时增加特征图的通道数。
- 特征提取:
- 随着网络的深入,特征图的通道数逐渐增加,通常是 64、128、256、512 等。
2. 解码器(上采样路径)
- 上采样层:
- 解码器通过上采样层逐步恢复特征图的空间分辨率。上采样可以通过转置卷积(Transpose Convolution)或双线性插值(Bilinear Interpolation)实现。
- 每次上采样后,特征图的空间维度翻倍。
- 跳跃连接:
- 在每个上采样步骤中,将对应的编码器层的特征图与解码器的特征图进行拼接(concatenation)。这种跳跃连接允许模型在上采样时保留低级特征,从而帮助恢复细节信息。
- 卷积层:
- 在解码器的每个上采样步骤后,通常会有一个卷积层来进一步处理拼接后的特征图,以减少通道数并提取特征。
3. 输出层
- 1x1 卷积:
- 最后,解码器的输出通过一个 1x1 卷积层进行处理,以生成与输入图像相同尺寸的输出特征图。输出特征图的通道数通常与分割任务的类别数相同。
4. U-Net 的整体结构示意图
Input Image
|
v
[Conv2D + ReLU] -> [Conv2D + ReLU] -> [Max Pooling]
|
v
[Conv2D + ReLU] -> [Conv2D + ReLU] -> [Max Pooling]
|
v
[Conv2D + ReLU] -> [Conv2D + ReLU] -> [Max Pooling]
|
v
[Conv2D + ReLU] -> [Conv2D + ReLU] -> [Max Pooling]
|
v
[Conv2D + ReLU] -> [Conv2D + ReLU]
|
v
[UpSampling] + [Skip Connection] -> [Conv2D + ReLU] -> [Conv2D + ReLU]
|
v
[UpSampling] + [Skip Connection] -> [Conv2D + ReLU] -> [Conv2D + ReLU]
|
v
[UpSampling] + [Skip Connection] -> [Conv2D + ReLU] -> [Conv2D + ReLU]
|
v
[UpSampling] + [Skip Connection] -> [Conv2D + ReLU] -> [Conv2D + ReLU]
|
v
[1x1 Conv] -> Output Segmentation Map
5. 代码 实现
下面是 U-Net 的实现代码,包含了编码器、解码器和跳跃连接的结构。
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, dropout_prob=0.5):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.LeakyReLU(negative_slope=0.01)
self.dropout = nn.Dropout(dropout_prob)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.dropout(x)
return x
class UpConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
self.relu = nn.ReLU()
def forward(self, x):
x = self.upconv(x)
x = self.relu(x)
return x
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
# 编码器部分
self.enc1 = ConvBlock(in_channels, 64)
self.enc2 = ConvBlock(64, 128)
self.enc3 = ConvBlock(128, 256)
self.enc4 = ConvBlock(256, 512)
# 中间层
self.bottleneck = ConvBlock(512, 1024)
# 解码器部分
self.dec4 = UpConvBlock(1024, 512)
self.dec4_conv = ConvBlock(1024, 512)
self.dec3 = UpConvBlock(512, 256)
self.dec3_conv = ConvBlock(512, 256)
self.dec2 = UpConvBlock(256, 128)
self.dec2_conv = ConvBlock(256, 128)
self.dec1 = UpConvBlock(128, 64)
self.dec1_conv = ConvBlock(128, 64)
# 输出层
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
print(f"Input shape: {x.shape}")
# 编码器部分
enc1 = self.enc1(x)
print(f"Output shape after enc1: {enc1.shape}")
enc1_pooled = F.max_pool2d(enc1, kernel_size=2)
print(f"Output shape after max_pool2d (enc1): {enc1_pooled.shape}")
enc2 = self.enc2(enc1_pooled)
print(f"Output shape after enc2: {enc2.shape}")
enc2_pooled = F.max_pool2d(enc2, kernel_size=2)
print(f"Output shape after max_pool2d (enc2): {enc2_pooled.shape}")
enc3 = self.enc3(enc2_pooled)
print(f"Output shape after enc3: {enc3.shape}")
enc3_pooled = F.max_pool2d(enc3, kernel_size=2)
print(f"Output shape after max_pool2d (enc3): {enc3_pooled.shape}")
enc4 = self.enc4(enc3_pooled)
print(f"Output shape after enc4: {enc4.shape}")
enc4_pooled = F.max_pool2d(enc4, kernel_size=2)
print(f"Output shape after max_pool2d (enc4): {enc4_pooled.shape}")
# 中间层
bottleneck = self.bottleneck(enc4_pooled)
print(f"Output shape after bottleneck: {bottleneck.shape}")
# 解码器部分
dec4 = self.dec4(bottleneck)
print(f"Output shape after dec4 (before resizing): {dec4.shape}")
dec4_resized = F.interpolate(dec4, size=enc4.shape[2:], mode='bilinear', align_corners=True)
print(f"Output shape after resizing dec4: {dec4_resized.shape}")
dec4 = torch.cat((dec4_resized, enc4), dim=1)
print(f"Output shape after concat dec4: {dec4.shape}")
dec4 = self.dec4_conv(dec4)
print(f"Output shape after dec4_conv: {dec4.shape}")
dec3 = self.dec3(dec4)
print(f"Output shape after dec3 (before resizing): {dec3.shape}")
dec3_resized = F.interpolate(dec3, size=enc3.shape[2:], mode='bilinear', align_corners=True)
print(f"Output shape after resizing dec3: {dec3_resized.shape}")
dec3 = torch.cat((dec3_resized, enc3), dim=1)
print(f"Output shape after concat dec3: {dec3.shape}")
dec3 = self.dec3_conv(dec3)
print(f"Output shape after dec3_conv: {dec3.shape}")
dec2 = self.dec2(dec3)
print(f"Output shape after dec2 (before resizing): {dec2.shape}")
dec2_resized = F.interpolate(dec2, size=enc2.shape[2:], mode='bilinear', align_corners=True)
print(f"Output shape after resizing dec2: {dec2_resized.shape}")
dec2 = torch.cat((dec2_resized, enc2), dim=1)
print(f"Output shape after concat dec2: {dec2.shape}")
dec2 = self.dec2_conv(dec2)
print(f"Output shape after dec2_conv: {dec2.shape}")
dec1 = self.dec1(dec2)
print(f"Output shape after dec1 (before resizing): {dec1.shape}")
dec1_resized = F.interpolate(dec1, size=enc1.shape[2:], mode='bilinear', align_corners=True)
print(f"Output shape after resizing dec1: {dec1_resized.shape}")
dec1 = torch.cat((dec1_resized, enc1), dim=1)
print(f"Output shape after concat dec1: {dec1.shape}")
dec1 = self.dec1_conv(dec1)
print(f"Output shape after dec1_conv: {dec1.shape}")
# 最后输出层
return self.final_conv(dec1)
# 示例用法
if __name__ == '__main__':
batch_size = 4
model = UNet(in_channels=1, out_channels=2) # 输入通道为1,输出通道为2(如二分类)
x = torch.randn(batch_size, 1, 572, 572) # 示例输入
preds = model(x)
print(f"Final output shape: {preds.shape}") # 输出形状
程序运行结果
Input shape: torch.Size([4, 1, 572, 572])
Output shape after enc1: torch.Size([4, 64, 572, 572])
Output shape after max_pool2d (enc1): torch.Size([4, 64, 286, 286])
Output shape after enc2: torch.Size([4, 128, 286, 286])
Output shape after max_pool2d (enc2): torch.Size([4, 128, 143, 143])
Output shape after enc3: torch.Size([4, 256, 143, 143])
Output shape after max_pool2d (enc3): torch.Size([4, 256, 71, 71])
Output shape after enc4: torch.Size([4, 512, 71, 71])
Output shape after max_pool2d (enc4): torch.Size([4, 512, 35, 35])
Output shape after bottleneck: torch.Size([4, 1024, 35, 35])
Output shape after dec4 (before resizing): torch.Size([4, 512, 70, 70])
Output shape after resizing dec4: torch.Size([4, 512, 71, 71])
Output shape after concat dec4: torch.Size([4, 1024, 71, 71])
Output shape after dec4_conv: torch.Size([4, 512, 71, 71])
Output shape after dec3 (before resizing): torch.Size([4, 256, 142, 142])
Output shape after resizing dec3: torch.Size([4, 256, 143, 143])
Output shape after concat dec3: torch.Size([4, 512, 143, 143])
Output shape after dec3_conv: torch.Size([4, 256, 143, 143])
Output shape after dec2 (before resizing): torch.Size([4, 128, 286, 286])
Output shape after resizing dec2: torch.Size([4, 128, 286, 286])
Output shape after concat dec2: torch.Size([4, 256, 286, 286])
Output shape after dec2_conv: torch.Size([4, 128, 286, 286])
Output shape after dec1 (before resizing): torch.Size([4, 64, 572, 572])
Output shape after resizing dec1: torch.Size([4, 64, 572, 572])
Output shape after concat dec1: torch.Size([4, 128, 572, 572])
Output shape after dec1_conv: torch.Size([4, 64, 572, 572])
Final output shape: torch.Size([4, 2, 572, 572])
注意事项
- 卷积层的定义:
- 在解码器部分,确保每个阶段的卷积层都被正确定义,以便在跳跃连接后进行两个卷积操作。
- 跳跃连接:
- 在进行跳跃连接之前,使用
F.interpolate
调整解码器输出的空间维度,以确保与编码器的特征图匹配。
- 在进行跳跃连接之前,使用
- 激活函数和正则化:
- 使用
LeakyReLU
和BatchNorm2d
可以提高模型的稳定性和性能。根据需要调整Dropout
的概率,以防止过拟合。
- 使用