简述
-
ResNet是一种非常成功的深度卷积神经网络结构,其具有较强的特征表达能力和较浅的网络深度,使得其在图像分类等任务中表现出了出色的性能。因此,将ResNet作为encoder替换U-Net原始结构,可以使U-Net在图像分割任务中获得更好的性能表现。
-
U-Net是一种经典的深度卷积神经网络结构,特别适用于图像分割任务。U-Net提出的时间较早,当时并没有像ResNet等网络结构和大规模预训练权重这样的资源可用。但是,U-Net的下采样和上采样的设计思路和现在许多成熟的网络结构相似,因此,可以看作是先驱性的工作。
-
U-Net的下采样和上采样的设计与许多现在成熟的网络结构异曲同工。具体地,U-Net的下采样部分使用卷积和池化操作来逐渐减小特征图的尺寸和通道数,提取低级别特征。特征图每一层的尺寸会降低一半,而通道数会翻倍。这种设计与现在许多成熟的网络结构(如ResNet、VGG等)的下采样部分使用卷积和池化操作来逐渐减小特征图的尺寸和通道数的设计思路相似。
-
成熟的网络结构和ImageNet预训练权重可以用来finetuning我们的U-Net。因为ImageNet是一个大规模的图像分类数据集,ImageNet预训练权重可以帮助我们在U-Net的训练中使用更好的初始化权重,加快网络的收敛速度并提高网络的泛化能力。通过finetuning,我们可以进一步优化U-Net网络的效果。
-
ResUNet是一种基于残差连接的深度学习模型,用于图像分割任务。它结合了ResNet和U-Net的优点,能够更好地解决梯度消失问题和语义信息缺失问题。下面将介绍ResUNet的原理。
原理
ResNet
Residual Network,是一种深度卷积神经网络
基于残差连接的思想,使得网络更容易训练
残差连接:跨层连接方法,可以使得网络更好地学习到低频信息
U-Net
一种全卷积网络,用于图像分割任务
包含编码器和解码器,能够提取全局和局部特征
上下文信息融合,可以缓解语义信息缺失问题
ResUNet
在U-Net的基础上加入了残差连接
编码器和解码器中的每个模块都包含了多个残差连接
在每个残差块中,引入了shortcut(或者称为skip connection)实现跨层连接
ResUNet的优点
残差连接可以缓解深度网络的梯度消失问题
编码器和解码器中的残差块可以更好地提取低频信息
上下文信息融合可以缓解语义信息缺失问题
实验结果表明,与U-Net相比,ResUNet在分割准确率上有显著提高。
总之,ResUNet是一种基于残差连接的深度学习模型,结合了ResNet和U-Net的优点。通过在编码器和解码器中引入多个残差块和shortcut,可以更好地提取低频信息、缓解语义信息缺失问题和梯度消失问题,从而提高图像分割的准确率。
示意图
可以看到,ResUNet包括了下采样、上采样和跳跃连接三个部分。
下采样部分使用卷积和池化操作逐渐减小图像尺寸和特征数量,提取低级别特征。
上采样部分使用转置卷积操作逐渐增大图像尺寸和特征数量,同时进行特征融合,生成高级别特征。
跳跃连接部分将下采样和上采样过程中相同分辨率的特征进行连接,帮助网络更好地捕捉多尺度信息,提高图像分割性能。
整个网络结构包括了ResNet的残差连接和U-Net的上下采样和跳跃连接思想,可以更好地平衡特征的丰富性和细节的保留性,在图像分割任务中表现出较好的性能。
代码实现
from torch import nn
import torchvision.models as models
import torch.nn.functional as F
from torchsummary import summary
# 定义解码器中的卷积块
class expansive_block(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels):
super(expansive_block, self).__init__()
# 卷积块的结构
self.block = nn.Sequential(
nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=mid_channels, padding=1),
nn.ReLU(),
nn.BatchNorm2d(mid_channels),
nn.Conv2d(kernel_size=(3, 3), in_channels=mid_channels, out_channels=out_channels, padding=1),
nn.ReLU(),
nn.BatchNorm2d(out_channels)
)
def forward(self, d, e=None):
# 上采样
d = F.interpolate(d, scale_factor=2, mode='bilinear', align_corners=True)
# 拼接
if e is not None:
cat = torch.cat([e, d], dim=1)
out = self.block(cat)
else:
out = self.block(d)
return out
# 定义最后一层卷积块
def final_block(in_channels, out_channels):
block = nn.Sequential(
nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=out_channels, padding=1),
nn.ReLU(),
nn.BatchNorm2d(out_channels),
)
return block
# 定义 Resnet34_Unet 类
class Resnet34_Unet(nn.Module):
# 定义初始化函数
def __init__(self, in_channel, out_channel, pretrained=False):
# 调用 nn.Module 的初始化函数
super(Resnet34_Unet, self).__init__()
# 创建 ResNet34 模型
self.resnet = models.resnet34(pretrained=pretrained)
# 定义 layer0,包括 ResNet34 的第一层卷积、批归一化、ReLU 和最大池化层
self.layer0 = nn.Sequential(
self.resnet.conv1,
self.resnet.bn1,
self.resnet.relu,
self.resnet.maxpool
)
# 定义 Encode 部分,包括 ResNet34 的 layer1、layer2、layer3 和 layer4
self.layer1 = self.resnet.layer1
self.layer2 = self.resnet.layer2
self.layer3 = self.resnet.layer3
self.layer4 = self.resnet.layer4
# 定义 Bottleneck 部分,包括两个卷积层、ReLU、批归一化和最大池化层
self.bottleneck = torch.nn.Sequential(
nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=1024, padding=1),
nn.ReLU(),
nn.BatchNorm2d(1024),
nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, padding=1),
nn.ReLU(),
nn.BatchNorm2d(1024),
nn.MaxPool2d(kernel_size=(2, 2), stride=2)
)
# 定义 Decode 部分,包括四个 expansive_block 和一个 final_block
self.conv_decode4 = expansive_block(1024+512, 512, 512)
self.conv_decode3 = expansive_block(512+256, 256, 256)
self.conv_decode2 = expansive_block(256+128, 128, 128)
self.conv_decode1 = expansive_block(128+64, 64, 64)
self.conv_decode0 = expansive_block(64, 32, 32)
self.final_layer = final_block(32, out_channel)
# 定义前向传播函数
def forward(self, x):
# 执行 layer0
x = self.layer0(x)
# 执行 Encode
encode_block1 = self.layer1(x)
encode_block2 = self.layer2(encode_block1)
encode_block3 = self.layer3(encode_block2)
encode_block4 = self.layer4(encode_block3)
# 执行 Bottleneck
bottleneck = self.bottleneck(encode_block4)
# 执行 Decode
decode_block4 = self.conv_decode4(bottleneck, encode_block4)
decode_block3 = self.conv_decode3(decode_block4, encode_block3)
decode_block2 = self.conv_decode2(decode_block3, encode_block2)
decode_block1 = self.conv_decode1(decode_block2, encode_block1)
decode_block0 = self.conv_decode0(decode_block1)
final_layer = self.final_layer(decode_block0)
return final_layer
flag = 0
if flag:
image = torch.rand(1, 3, 572, 572)
Resnet34_Unet = Resnet34_Unet(in_channel=3, out_channel=1)
mask = Resnet34_Unet(image)
print(mask.shape)
# 测试网络
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Resnet34_Unet(in_channel=1, out_channel=1, pretrained=True).to(device)
summary(model, input_size=(3, 512, 512))
该代码定义了一个基于ResNet34的Unet网络,用于语义分割任务。主要包括以下几个部分:
- expansive_block:扩张块,由两个卷积层、ReLU和BatchNorm组成。用于解码过程中对图像进行上采样和特征融合操作。
- final_block:最终块,由一个卷积层、ReLU和BatchNorm组成。用于将解码后的特征图转换为最终的输出图像。
- Resnet34_Unet:整个网络的主体部分。首先使用ResNet34作为编码器,对输入图像进行特征提取。然后通过一个卷积层和ReLU,将编码器的输出进行特征扩张。接下来进行解码操作,使用扩张块和编码器的特征图进行上采样和特征融合,直到得到与原始输入图像大小相同的特征图。最后通过最终块将特征图转换为输出图像
- flag:用于测试代码,如果设置为1,则会生成一个随机输入图像,并输出对应的分割结果。
- 测试网络:实例化Resnet34_Unet网络,并使用torchsummary库输出网络结构的信息,包括每一层的输出形状和参数数量等。