深度学习--图像分割UNet介绍及代码分析

参考

U-Net: Convolutional Networks for Biomedical Image Segmentation 输入文章名自行查询
参考博客

UNet网络介绍

整体架构

在这里插入图片描述

U-net 架构(以最低分辨率为 32x32 像素为例),每个蓝色框对应一个多通道特征图。通道数显示在框的顶部。x-y 大小位于框的左下边缘。白框表示复制的特征图。箭头表示不同的操作。
深蓝色箭头:利用3×3的卷积核对图片进行卷积后,通过ReLU激活函数输出特征通道;
灰色箭头:对左边下采样过程中的图片进行裁剪复制;
红色箭头:通过最大池化对图片进行下采样,池化核大小为2×2;
绿色箭头:反卷积,对图像进行上采样,卷积核大小为2×2;
青色箭头:使用1×1的卷积核对图片进行卷积。
因为网络形状像U,故被称为U-net
参考博客(网络结构很清晰,推荐)

UNet过程

输入

U-Net 的输入是一幅单通道的图像,通常大小为 572x572 像素,由于在不断valid卷积过程中,会使得图片越来越小,为了避免数据丢失,在图像输入前都需要进行镜像扩大。

编码器(下采样)

  • U-Net 的编码器部分输入图像通过卷积层进行特征提取,这些卷积层通常使用 3x3 的卷积核,逐步提取图像特征并缩小空间维度。
  • 然后,通过池化层(通常是最大池化)将图像的空间维度减小,例如从 572x572 缩小到 286x286。
  • 这个过程会重复多次,每次都会减小图像的空间维度和增加特征通道数。

中间特征表示

  • 在编码器的最后一层,我们获得了一个中间特征表示,通常是一个高维的特征张量。
  • 这个特征表示包含了图像的抽象特征,可以用于后续的分割任务。

解码器(上采样)

  • U-Net 的解码器部分将中间特征表示还原到原始的空间维度,并逐步增加分辨率。
  • 首先,通过上采样操作将特征张量的空间维度扩大,例如从 286x286 扩大到 572x572。
  • 然后,通过卷积层进行特征融合,将低级和高级特征结合起来。
  • 最后,输出通道数为 64 的卷积层将特征映射到最终的分割结果。

输出

  • U-Net 的输出是一个分割图像,大小与输入图像相同(通常为 572x572 像素),这幅分割图像被分成不同的区域,其中不同区域被分配不同的标签或类别。
  • 分割图像中的每个像素都被分类到不同的类别中,即可以准确地知道图像中的每个像素属于哪个结构或区域。这个分割图像可以用于识别生物医学图像中的不同结构,例如肿瘤、器官等。

代码详解

unetUP和Unet关系

  1. unetUp:

    • unetUp 是一个自定义的 PyTorch 模块(nn.Module),用于实现 U-Net 模型中的上采样部分。
    • 它接受两个输入特征张量 inputs1inputs2,并将它们进行上采样、特征融合和卷积操作,最终输出一个特征张量。
    • 在 U-Net 中,unetUp 负责将低分辨率的特征图上采样到与高分辨率特征图相同的尺寸,以便进行特征融合。
  2. Unet:

    • Unet 是整个 U-Net 模型的主体部分,它由多个 unetUp 模块组成。
    • 根据选择的 backbone(可以是 VGG 或 ResNet-50),Unet 使用不同的主干网络提取特征。
    • Unet 的前向传播过程包括多次特征融合,上采样和卷积操作,最终生成语义分割结果。

上采样模块——unetUp

这段代码定义了一个名为 unetUp 的类,它是一个用于UNet架构中的上采样模块。这个模块的作用是将低分辨率特征图上采样并与高分辨率特征图结合,以生成更高分辨率的输出。

# 定义unetUp类,unetUp类继承自nn.Module,是PyTorch中所有神经网络模块的基类。
class unetUp(nn.Module):
# 初始化方法
in_size 和 out_size 是输入和输出通道的数量。
self.conv1 和 self.conv2 是两个二维卷积层,卷积核大小为3,填充为1。
self.up 是一个最近邻插值的上采样层,放大倍数为2。
self.relu 是一个ReLU激活函数。

    def __init__(self, in_size, out_size):
        super(unetUp, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
        self.up = nn.UpsamplingNearest2d(scale_factor=2)
        self.relu = nn.ReLU(inplace=True)
# 前向传播方法
# inputs1 和 inputs2 是前向传播时的输入张量。
# torch.cat([inputs1, self.up(inputs2)], 1) 将 inputs1 和上采样后的 inputs2 在通道维度上拼接。
# outputs = self.conv1(outputs) 对拼接后的张量进行第一次卷积。
# outputs = self.relu(outputs) 对卷积结果应用ReLU激活函数。
# outputs = self.conv2(outputs) 对激活后的张量进行第二次卷积。
# outputs = self.relu(outputs) 再次应用ReLU激活函数。最终返回处理后的 outputs。
# 这段代码的主要功能是将低分辨率特征图上采样并与高分辨率特征图结合,经过两次卷积和激活函数处理后,生成更高分辨率的输出特征图。

    def forward(self, inputs1, inputs2):
        outputs = torch.cat([inputs1,self.up(inputs2)],1)
        outputs = self.conv1(outputs)
        outputs = self.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = self.relu(outputs)
        return outputs

用于图像分割的卷积神经网络(CNN)架构模块——Unet

下面这段代码定义了一个名为 Unet 的类,它是一个用于图像分割的卷积神经网络(CNN)架构。这个类可以使用不同的骨干网络(backbone),如VGG16或ResNet50,并包含上采样模块以生成高分辨率的输出。以下是对代码的详细解释:

类的定义

class Unet(nn.Module):

Unet 类继承自 nn.Module,这是PyTorch中所有神经网络模块的基类。

初始化方法

def __init__(self, num_classes=21, pretrained=False, backbone='vgg'):
    super(Unet, self).__init__()
    if backbone == 'vgg':
        self.vgg = VGG16(pretrained=pretrained)
        in_filters = [192, 384, 768, 1024]
    elif backbone == "resnet50":
        self.resnet = resnet50(pretrained=pretrained)
        in_filters = [192, 512, 1024, 3072]
    else:
        raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
    out_filters = [64, 128, 256, 512]
  • num_classes 是输出类别的数量。
  • pretrained 指示是否使用预训练的权重。
  • backbone 指定使用的骨干网络,可以是VGG16或ResNet50。
  • 根据选择的骨干网络,初始化相应的网络并设置输入过滤器的数量。
  • in_filters(输入通道数):在卷积神经网络(CNN)中,in_filters 表示输入图像的通道数或特征图的数量。在输入层,如果是灰度图片,那就只有一个 feature map;如果是彩色图片,一般就是 3 个 feature map(对应红、绿、蓝通道)。在其他层,每个卷积核(也称为过滤器)与上一层的每个 feature map 做卷积,产生下一层的一个 feature map。因此,如果有 N 个卷积核,下一层就会产生 N 个 feature map。
  • out_filters(输出通道数):在卷积神经网络中,out_filters 表示卷积核的数量或输出的特征图数量。卷积核的个数决定了下一层的 feature map 数量。每个卷积核可以提取一种特征,并生成一个新的特征图。在多层卷积网络中,下一层的卷积核的通道数等于上一层的 feature map 数量。如果通道数不相等,就无法继续进行卷积操作。

上采样模块

# upsampling
self.up_concat4 = unetUp(in_filters[3], out_filters[3])
self.up_concat3 = unetUp(in_filters[2], out_filters[2])
self.up_concat2 = unetUp(in_filters[1], out_filters[1])
self.up_concat1 = unetUp(in_filters[0], out_filters[0])
  • 定义四个上采样模块,每个模块将低分辨率特征图上采样到更高分辨率并与高分辨率特征图结合。
  • self.up_concat4, self.up_concat3, self.up_concat2 , self.up_concat1 是上采样操作的一部分。它们分别将不同层的特征图级联在一起,以获得更丰富的特征表示

额外的上采样卷积层(仅用于ResNet50)

if backbone == 'resnet50':
    self.up_conv = nn.Sequential(
    # 使用双线性插值(nn.UpsamplingBilinear2d)将特征图的大小放大两倍
        nn.UpsamplingBilinear2d(scale_factor=2), 
        # 通过两个卷积层对特征图进行处理,以获得更好的特征表示
        nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
         # 使用 nn.ReLU() 激活函数来确保非线性变换
        nn.ReLU(),
        nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
        nn.ReLU(),
    )
else:
    self.up_conv = None
  • 如果使用ResNet50作为骨干网络,定义一个额外的上采样卷积层。

最终卷积层

# self.final 是一个卷积层,用于生成最终的输出。它将高分辨率的特征图映射到类别数(num_classes)
self.final = nn.Conv2d(out_filters[0], num_classes, 1)
  • 定义一个最终的卷积层,将输出通道数转换为类别数。

前向传播方法

定义模型的前向传播过程,将输入数据通过网络的各层进行计算,最终生成输出。

def forward(self, inputs):
# 根据 self.backbone 的值,选择不同的模型(VGG 或 ResNet-50)进行前向传播。
    	# 通过卷积层和池化层对输入数据进行处理,得到特征图 feat1、feat2、feat3、feat4 和 feat5
    if self.backbone == "vgg":
        [feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs)
    elif self.backbone == "resnet50":
        [feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs)
# 通过上采样操作将这些特征图进行级联,得到更高分辨率的特征图 up4、up3、up2 和 up1
    up4 = self.up_concat4(feat4, feat5)
    up3 = self.up_concat3(feat3, up4)
    up2 = self.up_concat2(feat2, up3)
    up1 = self.up_concat1(feat1, up2)
# 如果存在上采样卷积层 self.up_conv,则对 up1 进行进一步处理
    if self.up_conv != None:
        up1 = self.up_conv(up1)
# 通过 self.final 层获得最终的输出
    final = self.final(up1)
    
    return final
  • 根据选择的骨干网络,获取不同层的特征图。
  • 使用上采样模块逐层上采样并结合特征图。
  • 如果定义了额外的上采样卷积层,则应用该层。
  • 最终通过一个卷积层生成输出。

冻结和解冻骨干网络

def freeze_backbone(self):
    if self.backbone == "vgg":
        for param in self.vgg.parameters():
            param.requires_grad = False
    elif self.backbone == "resnet50":
        for param in self.resnet.parameters():
            param.requires_grad = False

def unfreeze_backbone(self):
    if self.backbone == "vgg":
        for param in self.vgg.parameters():
            param.requires_grad = True
    elif self.backbone == "resnet50":
        for param in self.resnet.parameters():
            param.requires_grad = True
  • freeze_backbone 方法用于冻结骨干网络的参数,使其在训练过程中不更新。
  • unfreeze_backbone 方法用于解冻骨干网络的参数,使其在训练过程中可以更新。

冻结或解冻神经网络模型的特定层–更好地进行迁移学习或微调

  1. 迁移学习

    • 在迁移学习中,使用一个预训练的神经网络模型(通常在大规模数据集上进行训练)来解决新的任务。
    • 通过冻结模型的底层层(例如卷积层),可以保留其在原始任务上学到的特征表示,然后在新任务上进行微调。
    • 这样做有助于避免在新任务上过拟合,并且可以加快训练速度。
  2. 微调

    • 微调是指在预训练模型的基础上继续训练,以适应新任务的特定数据。
    • 解冻底层层,允许其权重在新任务上进行调整,以更好地适应新数据。
    • 通常,我们只微调模型的一部分,而不是整个模型,以避免丢失预训练模型的有用特征。

完整代码

import torch
import torch.nn as nn

from nets.resnet import resnet50
from nets.vgg import VGG16

class unetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(unetUp, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
        self.up = nn.UpsamplingNearest2d(scale_factor=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs1, inputs2):
        outputs = torch.cat([inputs1,self.up(inputs2)],1)
        outputs = self.conv1(outputs)
        outputs = self.relu(outputs)
        outputs = self.conv2(outputs)
        outputs = self.relu(outputs)
        return outputs
class Unet(nn.Module):
    def __init__(self, num_classes = 2, pretrained = False, backbone = 'vgg'):
        super(Unet, self).__init__()
        if backbone == 'vgg':
            self.vgg = VGG16(pretrained=pretrained)
            in_filters = [192, 384, 768, 1024]
        elif backbone == 'resnet50'
            self.resnet = resnet50(pretrained=pretrained)
            in_filters = [192, 512, 1024, 3072]
        else:
            raise ValueError('Unsupported backbone -`{}`, Use vgg, resnet50.'.format(backbone))
        out_filters = [64, 128, 256, 512]
        #???
        self.up_concat4 = unetUp(in_filters[3], out_filters[3])
        self.up_concat3 = unetUp(in_filters[2], out_filters[2])
        self.up_concat2 = unetUp(in_filters[1], out_filters[1])
        self.up_concat1 = unetUp(in_filters[0], out_filters[0])

        if backbone == 'resnet50':
            self.up_conv = nn.Sequential(
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.Conv2d(out_filters[0],out_filters[0], kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
                nn.ReLU(),
            )
        else:
            self.up_conv = None

        self.final = nn.Conv2d(out_filters[0], num_classes, 1)
        self.backbone = backbone

    def forward(self, inputs):
        if self.backbone == "vgg":
            [feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs)
        elif self.backbone == "resnet50":
            [feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs)

        up4 = self.up_concat4(feat4, feat5)
        up3 = self.up_concat3(feat3, up4)
        up2 = self.up_concat2(feat2, up3)
        up1 = self.up_concat1(feat1, up2)

        if self.up_conv != None:
            up1 = self.up_conv(up1)

        final = self.final(up1)
        return final

    def freeze_backbone(self):
        if self.backbone == "vgg":
            for param in self.vgg.parameters():
                param.requires_grad = False
        elif self.backbone == "resnet50":
            for param in self.resnet.parameters():
                param.requires_grad = False

    def unfreeze_backbone(self):
        if self.backbone == "vgg":
            for param in self.vgg.parameters():
                param.requires_grad = True
        elif self.backbone == "resnet50":
            for param in self.resnet.parameters():
                param.requires_grad = True
                
  • 24
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这里提供一个简单的UNet代码示例,用于图像分割任务: ```python import torch import torch.nn as nn import torch.nn.functional as F class DoubleConv(nn.Module): def __init__(self, in_channels, out_channels): super(DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, 3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): return self.conv(x) class UNet(nn.Module): def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]): super(UNet, self).__init__() self.ups = nn.ModuleList() self.downs = nn.ModuleList() self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # Down part of UNet for feature in features: self.downs.append(DoubleConv(in_channels, feature)) in_channels = feature # Up part of UNet for feature in reversed(features): self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)) self.ups.append(DoubleConv(feature*2, feature)) self.bottleneck = DoubleConv(features[-1], features[-1]*2) self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) def forward(self, x): skip_connections = [] for down in self.downs: x = down(x) skip_connections.append(x) x = self.pool(x) x = self.bottleneck(x) skip_connections = skip_connections[::-1] for idx in range(0, len(self.ups), 2): x = self.ups[idx](x) skip_connection = skip_connections[idx//2] if x.shape != skip_connection.shape: x = F.interpolate(x, size=skip_connection.shape[2:], mode='bilinear', align_corners=True) concat_skip = torch.cat((skip_connection, x), dim=1) x = self.ups[idx+1](concat_skip) return self.final_conv(x) ``` 这里的代码实现了一个简单的UNet模型,其中包含了双卷积层和上采样层,用于图像分割任务。在构建模型时,可以根据任务需要自行修改输入、输出通道数以及特征层数组`features`。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值