UNETR:用于3D图像分割的Transformer

论文:UNETR: Transformers for 3D Medical Image
Segmentation
论文地址:https://arxiv.org/pdf/2103.10504.pdf
摘要
近年来,具有收缩路径和扩展路径(例如,编码器和解码器)的全卷积神经网络(FCNN)在各种医学图像分割应用中表现出突出的地位。在这些体系结构中,编码器通过学习全局上下文信息成为一个不可或缺的角色,而此过程中获取的全局上下文表示形式将被解码器进一步用于语义输出预测。尽管取得了成功,但作为FCNN的主要构建模块的卷积层的局限性,限制了在此类网络中学习远程空间相关性的能力。受自然语言处理(NLP)转换器在远程序列学习中的最新成功的启发,我们将体积(3D)医学图像分割的任务重新设计为序列到序列的预测问题。特别是,我们介绍了一种称为UNEt变压器(UNETR)的新颖架构,该架构利用纯变压器作为编码器来学习输入量的序列表示并有效地捕获全局多尺度信息。转换器编码器通过不同分辨率的跳跃连接直接连接到解码器,以计算最终的语义分段输出。我们已经使用医学分割十项全能(MSD)数据集广泛验证了我们提出的模型在不同成像方式(即MR和CT)上对体积脑肿瘤和脾脏分割任务的性能,并且我们的结果始终证明了良好的基准。

1 引言
医学图像分割在许多临床诊断方法中起着不可或缺的作用,并且通常是对解剖结构进行定量分析的第一步。 自从深度学习问世以来,FCNN,尤其是编码器-解码器体系结构[19,13,14,12]在各种医学语义分割任务[1,22,11]中已取得了最新的成果。 在典型的U-Net [21]体系结构中,

  • 编码器负责通过逐渐降低采样的特征来学习全局上下文表示;
  • 解码器则将采样的表示上采样至输入分辨率,以进行像素/体素语义预测;
  • 另外,跳跃连接以不同的分辨率合并了编码器和解码器的输出,因此可以恢复在下采样期间丢失的空间信息。
    尽管此类基于FCNN的方法具有强大的表示学习功能,但它们在学习远程依存关系方面的性能仅限于其局部接收域。结果,在捕获多尺度上下文信息中的这种缺陷导致具有可变形状和尺度(例如,具有不同大小的脑损伤)的结构的分割。通过采用无意识的卷积层[4,15,10],已经做出了一些努力来缓解这个问题。但是,由于CNN的局限性,它们的接收场仍然局限于一个小区域。
    在NLP领域,基于变压器的模型[24,6]在各种任务中都达到了最新的基准。变压器中的自我注意机制使他们能够动态地突出显示单词序列的重要特征并了解其长期依赖性。最近,通过引入Visual Transformer(ViT)[7],该概念已扩展到计算机视觉。在ViT中,图像表示为一系列补丁嵌入,这些补丁嵌入将用于直接预测类别标签以进行图像分类。
    在这项工作中,我们建议利用变压器进行体积医学图像分割,并为此目的引入一种被称为UNETR的新颖架构。特别是,我们将3D分割的任务重新设计为1D序列到序列的预测问题,并使用纯转换器作为编码器从嵌入的输入色块中学习上下文信息。从变压器编码器提取的表示通过多个分辨率的跳过连接与解码器合并,以预测分段输出。
    我们已经在MSD数据集中广泛验证了我们的UNETR对脑肿瘤和脾脏分割任务的有效性[22],并且与我们的验证集中的其他模型相比,我们的实验证明了良好的性能。据我们所知,我们是第一个提出用于体积医学图像分割的完全基于变压器的编码器的公司。考虑到体积数据在医学成像中的盛行及其在分割中的广泛应用,我们认为我们的UNETR为可用于各种应用的新型基于变压器的分割模型铺平了道路。

2 相关工作
基于CNN的分割网络:自从开创性的U-Net [21]以来,基于CNN的网络已在各种2D和3D各种医学图像分割任务上取得了最新的成果[8,29,25,9 , 16,28]。 尽管取得了成功,但这些网络的局限性在于它们在学习全局上下文和长期空间依赖方面的表现不佳,这可能严重影响具有挑战性的任务的分割性能。
Visual Transformers:Visual Transformers最近在各种计算机视觉任务中获得了关注。 Dosovitskiy等。 [7]通过对纯变压器的大规模预训练和微调,在图像分类数据集上展示了最新的性能。在目标检测中,基于端到端变压器的模型在多个基准测试中显示出突出的地位[2,30]。最近,一些努力[27,3,23,26]已经探索了使用基于变压器的模型进行2D图像分割的可能性。 Chen等。 [3]通过在U-Net的瓶颈中采用变压器作为层,提出了一种用于多器官分割的2D方法。另外,Zhang等。 [26]建议在分开的流中使用CNN和变压器,并对它们的输出进行融合。瓦拉纳拉苏(Valanarasu)等。 [23]提出了一种基于变压器的轴向注意机制,用于2D医学图像分割。
我们的模型与这些工作之间存在三个主要区别:
(1)UNETR专为3D分割量身定制,并直接利用体积数据;
(2)UNETR使用变压器作为分段网络的主要编码器,并通过跳过连接将其直接连接到解码器,而不是将其用作分段网络中的关注层;
(3)UNETR不依赖于主干CNN来生成输入序列,而是直接利用标记化补丁。

3 方法论
3.1 架构
我们在图1中介绍了所提出模型的概述。UNETR利用收缩-扩展模式,该模式由一堆变压器组成,作为编码器,该编码器通过跳过连接与解码器连接。我们首先描述变压器编码器的工作机制。正如NLP中常用的那样,这些变压器以一维输入嵌入序列工作。在我们的UNETR中,通过将3D输入体积 x ∈ R H × W × D × C 3 x∈R^{H×W×D×C^{3}} xRH×W×D×C3分成平坦的均匀非重叠面片xv∈RL×C×N来创建一维序列,其中(N,N,N)表示每个补丁的维数,L =(H×W×D)/ N3是序列的长度。
如上图,如附图所示。1.对UNETR体系结构的概述。我们提取变压器中不同层的序列表示,并通过跳过连接将它们与解码器合并。补丁尺寸N=16和嵌入尺寸C=768的输出尺寸。
图1.对UNETR体系结构的概述。我们提取变压器中不同层的序列表示,并通过跳跃连接将它们与解码器合并。补丁尺寸N=16和嵌入尺寸C=768的输出尺寸。

随后,我们使用线性层将展平的贴片投影到K维嵌入空间中,该空间在整个变压器中保持恒定。此外,为了保留提取的补丁的空间信息,我们根据投影的嵌入 E ∈ R L 2 × C × K E∈R^{L^{2}×C×K} ERL2×C×K,将一维可学习位置嵌入 E p o s ∈ R L × D Epos∈R^{L×D} EposRL×D添加到投影的嵌入 E ∈ R L 2 × C × K E∈R^{L^{2}×C×K} ERL2×C×K中。
在这里插入图片描述
在嵌入层之后,我们根据以下内容利用了由多头自注意力(MSA)多层感知器(MLP)子层组成的一堆变压器块[24,7]
在这里插入图片描述
在范数表示层归一化的情况下,MLP由具有GELU激活函数的两个线性层组成, i i i是我们当前设置中的中间块标识符,范围从1到T = 12个总块。 一个MSA块包括n个并行的自我注意(SA)头。 (SA)块是一个参数化函数,用于学习输入序列(z)中两个元素及其查询(q)和键(k)表示形式的集合之间的相似性。 因此,(SA)的输出计算如下
在这里插入图片描述
其中 v v v表示输入序列中的值,而 C h = C / n C_{h} = C / n Ch=C/n是比例因子。 此外,MSA的输出定义为
在这里插入图片描述
其中 W m s a W_{msa} Wmsa代表不同头(SA)的可学习重量矩阵。
受类似UNet的体系结构的启发,将来自编码器多种分辨率的特征与解码器合并,我们提取了序列表示 z i ( i ∈ 3 , 6 , 9 , 12 ) z_{i}(i∈{3,6,9,12}) zii3,6,9,12,从变压器变形并重整为张量。 如果我们将其定义重塑为互感器的输出,并且特征尺寸为C(即互感器的嵌入尺寸),则该定义在嵌入空间中。 因此,我们通过利用连续的3×3×3卷积层,然后通过批处理归一化,将重整后的张量从嵌入空间投影到输入空间中(有关详细信息,请参见图1)。
在编码器的瓶颈处(即变压器最后一层的输出),我们将
反卷积层应用于变换后的特征图,以将其分辨率提高2倍
。然后,将调整大小后的特征图与上一个变压器的特征图连接起来输出(例如z9),将它们输入到连续的3×3×3卷积层中,并使用反卷积层对输出进行上采样。 对所有其他后续层重复此过程,直到达到原始输入分辨率为止,在此过程中,最终输出将被馈送到具有softmax激活函数的1×1×1卷积层中,以生成按像素的语义预测。
3.2 损失函数
我们的损失函数是Dice[18]和交叉熵项的组合,可以根据以下方法以体素方式计算它们:
在这里插入图片描述

I I I是体素的数量; J J J是类别数; Y i , j Y_{i,j} Yij G i , j G_{i,j} Gi,j分别表示在体素 i i i上类别 j j j的概率输出和一热编码的真相label。

4 实验
4.1 数据集
为了涵盖各种对象和图像模态,采用我们自己的5倍交叉验证数据拆分的实验,采用了来自MSD挑战[22]的任务1(脑肿瘤MRI分割)和任务9(脾CT分割)的数据集。 对于任务1,将具有神经胶质瘤分割坏死/活动性肿瘤和水肿的地面真相标签的484个多模式多站点MRI数据(FLAIR,T1w,T1gd,T2w)的整个训练集用于模型训练。 任务1的分辨率/间距统一为1.0×1.0×1.0 mm3。 对于任务9,使用带有脾脏注释的41个CT量。 任务9中卷的分辨率/空间范围为0.613×0.613×1.50 mm3至0.977×0.977×8.0 mm3 在预处理期间,将所有体积重新采样到1.0mm的各向同性体素中。
对于具有MRI图像的任务1,使用z分数归一化对体素强度进行了预处理。 对于具有CT图像的任务9,图像的体素强度根据总前景强度的第5个和第95个百分位数归一化为[0,1]范围。 此外,任务1的问题被公式化为具有4通道输入的3类分割任务,而任务9被公式化为具有单通道输入的二进制分割任务(前景和背景)。 对于任务1和任务9,我们分别以[128,128,128]和[96,96,96]的体积随机采样输入图像。 前景/背景的随机色块以1:1的比例进行采样。
实验结果:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 8
    点赞
  • 68
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
Swin Transformer是一种新型的Transformer结构,它在自然语言处理和计算机视觉领域都取得了很好的效果。在3D图像分割任务中,可以使用Swin Transformer来提取特征,然后使用U-Net结构进行分割。 以下是使用Swin Transformer进行3D图像分割的步骤: 1. 导入必要的库和模块,包括torch、torchvision、Swin Transformer和U-Net等。 2. 定义Swin Transformer编码器和U-Net解码器。编码器使用Swin Transformer提取特征,解码器使用U-Net进行分割。 3. 定义损失函数和优化器。在3D图像分割任务中,可以使用交叉熵损失函数和Adam优化器。 4. 加载数据集并进行预处理。可以使用torchvision中的transforms对数据进行预处理,例如缩放、裁剪、旋转等。 5. 训练模型。使用加载的数据集对模型进行训练,并在每个epoch结束时计算损失函数和准确率。 6. 测试模型。使用测试集对训练好的模型进行测试,并计算准确率和其他评价指标。 以下是一个示例代码,用于使用Swin Transformer进行3D图像分割: ```python import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms from swin_transformer import SwinTransformer3D from unet import UNet3D # 定义Swin Transformer编码器和U-Net解码器 encoder = SwinTransformer3D() decoder = UNet3D() # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001) # 加载数据集并进行预处理 transform = transforms.Compose([ transforms.Resize((128, 128, 128)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True) # 训练模型 for epoch in range(10): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = decoder(encoder(inputs)) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() print('[%d] loss: %.3f' % (epoch + 1, running_loss / len(trainloader))) # 测试模型 transform = transforms.Compose([ transforms.Resize((128, 128, 128)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False) correct = 0 total = 0 with torch.no_grad(): for data in testloader: images, labels = data outputs = decoder(encoder(images)) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total)) ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值