论文:U-Net: Convolutional Networks for Biomedical Image Segmentation(arXiv:1505.04597)
代码: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py
一、原始 U-Net 网络结构(输入 572x572x1,共 5 个分辨率层级,下采样 4 次):
每个蓝色框对应一个多通道特征图。通道的数量在框的顶部表示。x-y 大小在框的左下角提供。白盒代表复制的特征图。箭头表示不同的操作
1.0 框架代码(注:本文代码中的 3x3 卷积使用 padding=1,输出特征图与输入尺寸相同;原论文使用无填充卷积,572x572 的输入对应 388x388 的输出)
import torch
import torch.nn as nn
class UNet(nn.Module):
    """U-Net for per-pixel classification (semantic segmentation).

    Encoder: an initial DoubleConv followed by four Down stages
    (64 -> 128 -> 256 -> 512 -> 1024 channels). Decoder: four Up stages, each
    of which upsamples and concatenates the encoder feature map of the same
    level, followed by a 1x1 OutConv.

    Args:
        n_channels: number of input image channels (e.g. 3 for RGB).
        n_classes: number of output channels, one per class.
        bilinear: if True, the Up stages upsample with bilinear interpolation
            instead of transposed convolutions, and the deepest stages use half
            as many channels. Defaults to False.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Bilinear upsampling cannot reduce channels, so the bottleneck and the
        # decoder outputs are made half as wide instead.
        shrink = 2 if bilinear else 1

        # Encoder. Registration order is kept identical so that parameter and
        # state_dict ordering stay stable.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // shrink)

        # Decoder: each Up receives (deeper features, skip features).
        self.up1 = Up(1024, 512 // shrink, bilinear)
        self.up2 = Up(512, 256 // shrink, bilinear)
        self.up3 = Up(256, 128 // shrink, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # 1x1 projection to one output channel per class.
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: input batch, assumed to be (N, n_channels, H, W).

        Returns:
            Raw, unnormalized per-class scores (logits) with n_classes channels.
            No sigmoid/softmax is applied here; do that outside the model to
            obtain probabilities.
        """
        # Contracting path: keep every level's output for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        # Expanding path: upsample and merge with the matching encoder level.
        out = self.up1(bottom, skip4)
        out = self.up2(out, skip3)
        out = self.up3(out, skip2)
        out = self.up4(out, skip1)
        return self.outc(out)
1.1 DoubleConv:两次“3x3 卷积 + 批归一化 + ReLU”的实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) units.

    The convolutions use padding=1, so height and width are preserved. They
    use bias=False because the following BatchNorm supplies its own shift.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; any falsy
            value (None by default) means "use out_channels".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            layers.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ))
        # Indices 0..5 inside this Sequential form the state_dict keys
        # (double_conv.0.weight, double_conv.1.running_mean, ...).
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv-BN-ReLU units to x."""
        return self.double_conv(x)
1.2 Down:下采样 + 同层 2 次卷积(2x2 最大池化使特征图的高和宽减半,随后的双卷积改变通道数)
每次下采样后,经过 2 次卷积层
对应在架构图为
class Down(nn.Module):
    """Downsampling stage: 2x2 max-pool (halves H and W) followed by DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as a single Sequential named maxpool_conv so that state_dict keys
        # stay "maxpool_conv.1.double_conv...".
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool, then convolve x; spatial size is halved (rounded down)."""
        return self.maxpool_conv(x)
1.3 Up:上采样 + 拼接(特征图的高和宽加倍,再与编码器中对应层的特征图沿通道维拼接)
每次上采样后,要与对应下采样层输出的特征图拼接。原论文对编码器特征图做裁剪(crop)以对齐尺寸,本代码则对上采样结果做填充(padding),二者目的相同
拼接时选用分辨率相同的对称编码层(即跳跃连接),效果最好
拼接完成后,再经过 2 次卷积处理
代码中up操作
class Up(nn.Module):
    """Upsampling stage: upsample, pad to the skip map's size, concatenate, DoubleConv.

    Args:
        in_channels: channel count after concatenation, i.e. the input width of
            the DoubleConv. In transposed-conv mode this is also the channel
            count of the tensor being upsampled.
        out_channels: channels produced by the stage.
        bilinear: if True (default), upsample with fixed bilinear
            interpolation; otherwise use a learnable 2x2 transposed convolution
            that also halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            # Learnable upsampling: doubles H/W and maps in_channels -> in_channels // 2.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        # Interpolation keeps the channel count, so the DoubleConv's middle layer
        # (in_channels // 2 wide) performs the reduction instead.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample x1, align it with x2, concatenate along channels and convolve.

        Args:
            x1: features from the deeper (lower-resolution) level.
            x2: skip-connection features from the symmetric encoder level.

        Returns:
            Tensor with out_channels channels and x2's height and width.
        """
        x1 = self.up(x1)
        # Tensors are NCHW: dim 2 is height, dim 3 is width. When pooling
        # floored an odd size, the upsampled map is smaller than the skip map.
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        top, left = dy // 2, dx // 2
        # F.pad takes (left, right, top, bottom) for the last two dims; any odd
        # extra pixel goes to the right/bottom side.
        x1 = F.pad(x1, [left, dx - left, top, dy - top])
        # If you run into padding issues, see:
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        # Skip features first, upsampled features second (the order the DoubleConv weights expect).
        return self.conv(torch.cat([x2, x1], dim=1))
二、完整代码
# unet for car segmentation
# https://github.com/milesial/Pytorch-UNet
""" Parts of the U-Net model """
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) units.

    The convolutions use padding=1, so height and width are preserved. They
    use bias=False because the following BatchNorm supplies its own shift.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; any falsy
            value (None by default) means "use out_channels".
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            layers.extend((
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ))
        # Indices 0..5 inside this Sequential form the state_dict keys
        # (double_conv.0.weight, double_conv.1.running_mean, ...).
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv-BN-ReLU units to x."""
        return self.double_conv(x)
class Down(nn.Module):
    """Downsampling stage: 2x2 max-pool (halves H and W) followed by DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as a single Sequential named maxpool_conv so that state_dict keys
        # stay "maxpool_conv.1.double_conv...".
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool, then convolve x; spatial size is halved (rounded down)."""
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upsampling stage: upsample, pad to the skip map's size, concatenate, DoubleConv.

    Args:
        in_channels: channel count after concatenation, i.e. the input width of
            the DoubleConv. In transposed-conv mode this is also the channel
            count of the tensor being upsampled.
        out_channels: channels produced by the stage.
        bilinear: if True (default), upsample with fixed bilinear
            interpolation; otherwise use a learnable 2x2 transposed convolution
            that also halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            # Learnable upsampling: doubles H/W and maps in_channels -> in_channels // 2.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        # Interpolation keeps the channel count, so the DoubleConv's middle layer
        # (in_channels // 2 wide) performs the reduction instead.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample x1, align it with x2, concatenate along channels and convolve.

        Args:
            x1: features from the deeper (lower-resolution) level.
            x2: skip-connection features from the symmetric encoder level.

        Returns:
            Tensor with out_channels channels and x2's height and width.
        """
        x1 = self.up(x1)
        # Tensors are NCHW: dim 2 is height, dim 3 is width. When pooling
        # floored an odd size, the upsampled map is smaller than the skip map.
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        top, left = dy // 2, dx // 2
        # F.pad takes (left, right, top, bottom) for the last two dims; any odd
        # extra pixel goes to the right/bottom side.
        x1 = F.pad(x1, [left, dx - left, top, dy - top])
        # If you run into padding issues, see:
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        # Skip features first, upsampled features second (the order the DoubleConv weights expect).
        return self.conv(torch.cat([x2, x1], dim=1))
class OutConv(nn.Module):
    """Final 1x1 convolution that maps feature channels to per-class scores.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: number of output channels (one per class in UNet).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 1x1 kernel mixes channels independently at every pixel; bias is kept.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project x to out_channels with no activation (raw scores)."""
        scores = self.conv(x)
        return scores
class UNet(nn.Module):
    """U-Net: DoubleConv + 4 Down stages, then 4 Up stages with skip connections and a 1x1 OutConv.

    Args:
        n_channels: number of input image channels (e.g. 3 for RGB).
        n_classes: number of output channels, one per class.
        bilinear: if True, the Up stages use bilinear interpolation instead of
            transposed convolutions, and the deepest stages use half as many
            channels. Defaults to False.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Plain bool (not a parameter/buffer); switched on by use_checkpointing().
        self.checkpointing = False

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bilinear upsampling cannot reduce channels, so the deep widths are halved instead.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return raw per-class scores (logits) for the input batch x.

        x is assumed to be (N, n_channels, H, W). The output has n_classes
        channels; no sigmoid/softmax is applied.
        """
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Enable activation checkpointing for every stage of the network.

        Each stage's intermediate activations are discarded in the forward pass
        and recomputed during backward, which trades compute for memory.
        Submodules are left in place, so parameter names and state_dict keys
        are unchanged.

        NOTE(review): recomputation re-runs the BatchNorm layers in training
        mode, so their running statistics may be updated a second time per
        step; confirm this is acceptable for your training setup.
        """
        self.checkpointing = True

    def _run(self, module, *inputs):
        """Call module(*inputs), via torch.utils.checkpoint when checkpointing is active."""
        # getattr: models pickled before this flag existed lack the attribute.
        if getattr(self, 'checkpointing', False) and torch.is_grad_enabled():
            from torch.utils.checkpoint import checkpoint
            # use_reentrant=False: the reentrant variant produces no parameter
            # gradients when none of the inputs require grad (e.g. the raw image
            # fed to self.inc).
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)
if __name__ == '__main__':
    # Smoke test: push a 572x572 RGB image (the input size used in the paper)
    # through the network and report the output shape.
    input_image = torch.randn(1, 3, 572, 572)
    unet = UNet(n_channels=3, n_classes=1)
    # Inference only: no_grad avoids keeping every activation for backprop.
    with torch.no_grad():
        y = unet(input_image)
    # Convolutions use padding=1 and Up pads to the skip map's size, so the
    # output should keep the input's H and W: torch.Size([1, 1, 572, 572]).
    print('y', y.shape)
    print(unet)  # prints the module hierarchy (architecture), not a tensor shape

    # BatchNorm2d demo: one learnable gamma/beta pair per channel.
    batch_norm = nn.BatchNorm2d(num_features=64)
    print("Gamma (scale parameter):", batch_norm.weight.shape)
    print("Beta (shift parameter):", batch_norm.bias.shape)
    # Dummy input shaped (batch_size, channels, height, width).
    input_tensor = torch.randn(8, 64, 32, 32)
    # Module is in training mode, so per-batch statistics are used.
    output_tensor = batch_norm(input_tensor)
    print(output_tensor.shape)
附录
nn.BatchNorm2d 层相关
nn.BatchNorm2d 是 PyTorch 中用于批量归一化(Batch Normalization)的模块之一,专门用于处理 2D 图像数据(形状为 N×C×H×W 的 4D 输入)。批量归一化是一种在训练深度神经网络时加速和稳定训练过程的技术,它通过对每个小批量的输入按通道进行标准化来减少内部协变量偏移。
具体公式
$$y_{nchw} = \gamma_c \frac{x_{nchw} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c$$
参数解释
$x_{nchw}$:输入特征图中第 n 个样本、第 c 个通道、位置 (h, w) 处的值
$y_{nchw}$:批量归一化后输出特征图在相同位置的值
$\gamma_c$:通道 c 的可学习缩放参数(scale parameter,对应 `weight`),用于调整标准化后特征图的分布
$\beta_c$:通道 c 的可学习平移参数(shift parameter,对应 `bias`),用于调整标准化后特征图的偏移
$\epsilon$:一个很小的常数,用于防止除零,PyTorch 中默认取 1e-5(参数 `eps`)
对批大小为 N、特征图尺寸为 H×W 的输入,计算每个通道 c 的均值 $\mu_c$:
$$\mu_c = \frac{1}{N \times H \times W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{nchw}$$
计算每个通道 c 的方差 $\sigma_c^2$(有偏估计,除以 N×H×W):
$$\sigma_c^2 = \frac{1}{N \times H \times W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{nchw} - \mu_c)^2$$
训练模式下使用当前批次的统计量;eval 模式下改用训练期间累积的 running_mean 与 running_var。
标准化得到 $\hat{x}_{nchw}$:
$$\hat{x}_{nchw} = \frac{x_{nchw} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}$$
缩放和平移:
$$y_{nchw} = \gamma_c \hat{x}_{nchw} + \beta_c$$