医图顶会 MICCAI'24 | LKM-UNet: 用于医学图像分割的大内核视觉Mamba UNet

点击下方“ReadingPapers”卡片,每天获取顶刊论文解读

论文信息

题目:LKM-UNet: Large Kernel Vision Mamba UNet for Medical Image Segmentation
LKM-UNet: 大型内核视觉 Mamba UNet 用于医学图像分割
作者:Jinhong Wang, Jintai Chen, Danny Chen, Jian Wu
源码链接:https://github.com/wjh892521292/LKM-UNet

本文创新点

  1. 提出了大型内核 Mamba U 形网络 (LKM-UNet):作者提出了一种新的基于 Mamba 的 UNet 模型,用于 2D 和 3D 医学图像分割。这种模型利用 Mamba 的强大序列建模能力和线性复杂性,通过为 SSM 模块分配大内核来实现大感受野

  2. 设计了新颖的层次化和双向大型内核Mamba块(LM块):这种设计增强了 SSM 在视觉输入中的全局和邻域空间建模能力。特别是,LM 块包含**像素级 SSM (PiM)块级 SSM (PaM)**,这两者共同工作以增强局部邻域像素级和长距离全局块级建模。

  3. 引入了双向 Mamba (BiM):与原始的单向 Mamba 相比,作者提出的双向 Mamba 结构通过同时进行前向和后向扫描并叠加输出结果,改进了位置感知序列建模。这种双向设计使模型能够更好地关注图像中心区域的信息块,而不是仅仅关注角落区域,并且能够更好地建模每个块的绝对位置信息和与其他块的相对位置信息。

摘要

在临床实践中,医学图像分割提供了有关目标器官或组织的轮廓和尺寸的有用信息,有助于改善诊断、分析和治疗。在过去几年中,卷积神经网络(CNN)和 Transformers 在这一领域占据主导地位,但它们仍然受到有限的感受野或昂贵的长距离建模的限制。Mamba,一种状态空间序列模型(SSM),最近作为长距离依赖建模的有前途的范式出现,具有线性复杂性。在本文中,我们介绍了一种大型内核视觉 Mamba U 形网络,或称为 LKM-UNet,用于医学图像分割。我们 LKM-UNet 的一个区别特征是其使用大型 Mamba 内核,与基于小内核的 CNN 和 Transformers 相比,在局部空间建模方面表现出色,同时保持了与二次复杂性的自注意力相比的优越效率。此外,我们设计了一种新颖的层次化和双向 Mamba 块,以进一步增强 Mamba 对视觉输入的全局和邻域空间建模能力。综合实验表明,使用大尺寸 Mamba 内核实现大感受野是可行且有效的。

关键词

医学图像分割 · UNet · Mamba

0ec63215f117c9638a86f12a1fd0980d.jpeg

方法

在本节中,我们首先介绍 LKM-UNet 的整体架构。随后,我们详细阐述了核心组件,LM 块。

3.1 LKM-UNet

LKM-UNet 的概述如图 1 所示。具体来说,除了常见的 UNet 组成,包括深度可分离卷积、编码器下采样层、解码器上采样层和跳跃连接外,LKM-UNet 通过在编码器中插入提出的大内核 Mamba(LM)块来改进 UNet 的结构。给定一个分辨率为 C × D × H × W 的 3D 输入图像,深度可分离卷积首先将输入编码成特征图 F0 ∈ R48× D/2 × H/2 × W/2。然后特征图 F0 被送入每个 LM 块和相应的下采样层,并获得多尺度特征图。一个LM 块包含两个 Mamba 模块:像素级 SSM(PiM)和块级 SSM(PaM)。对于第 l 层,过程可以公式化为:

其中 PiM 和 PaM 分别表示像素级 SSM 和块级 SSM。Down-sampling 表示下采样层。在每个阶段之后,生成的特征图 被编码为 ,其中 表示特征图 的通道和分辨率。至于解码器部分,我们采用了 UNet 解码器和带有跳跃连接的残差块进行上采样和预测最终的分割掩码。

33226ab4ef75b8ba60bc110f40571746.jpeg

3.2 LM 块

LM 块是用于进一步空间建模每个阶段不同尺度的特征图的核心组件。与之前使用 CNN 进行局部像素级建模和 Transformer 进行长距离块级依赖性建模的方法不同,LM 块可以同时完成像素级和块级建模,得益于 Mamba 的线性复杂性。更重要的是,较低的复杂性允许设置更大的内核(窗口)以获得更大的感受野,这将提高局部建模的效率,如图 2(a) 所示。具体来说,LM 块是一个层次化设计,由像素级 SSM(PiM)和块级 SSM(PaM)组成;前者用于局部邻域像素建模,后者用于全局长距离依赖性建模。此外,LM 块中的每个 Mamba 层都是双向的,这是为了位置感知序列建模而提出的。像素级 SSM(PiM)。由于 Mamba 是一个连续模型,输入像素的离散性质可能会削弱局部相邻像素之间的相关性建模。因此,我们提出了一个像素级 SSM,将特征图分割成多个大子内核(子窗口)并对子内核进行 SSM 操作。我们首先将整个特征图等分为 2D 的非重叠子内核或 3D 的子立方体。以 2D 为例,给定一个 H × W 分辨率的输入,我们将特征图分割成大小为 m×n 的子内核(m 和 n 可以高达 40)。不失一般性,我们假设 H/m 和 W/n 都是整数。然后我们有 HW/mn 个子内核,如图 1 中的像素级 SSM 所示。在这种方案下,当这些子内核被送入 Mamba 层时,局部相邻像素将连续输入到 SSM;因此,局部邻域像素之间的关系可以更好地建模。此外,在大内核分割策略下,感受野被扩大,模型可以获得更多局部像素的细节。然而,图像被分割成非重叠的子内核。因此,我们需要一个机制来实现不同子内核之间的通信,以进行长距离依赖性建模。块级 SSM(PaM)。我们引入了一个块级 SSM 层,以在不同的子内核之间传递信息。如图 1 中的块级 SSM 所示,一个分辨率为 H × W 的特征图 首先通过一个大小为 m × n 的池化层,允许每个 HW/mn 个子内核的重要信息被总结成一个代表。因此,我们获得了具有 HW/mn 个代表的聚合图 ,然后这些聚合图被用来通过 Mamba 在子内核之间进行全局范围的依赖性建模。在 Mamba 中的交互之后,我们将聚合图反池化回与初始特征图 相同的大小,并应用残差连接。 在方程(5)中的过程可以执行为:

其中 Pooling 和 Unpooling 分别表示池化层和反池化层。Bi-Mamba 表示提出的双向 Mamba 层。双向 Mamba(BiM)。与基于前向扫描方向 SSM 层的原始 Mamba 块不同,LM 块中的每个 SSM 层(包括 PiM 和 PaM)都是双向的。图 2(b) 显示了差异。在原始 Mamba 中,作为一个连续模型,一些信息遗忘发生在早期进入的元素上,而最后进入 Mamba 的元素将保留更多的信息。因此,具有单一扫描方向的原始 Mamba 将更多地关注后部块,而不是通常包含更多器官和病变的特征图的中心区域。为此,我们提出了一个双向 Mamba 结构,通过同时进行前向和后向扫描并叠加输出结果。图 1 的左部分显示了详细结构。BiM 有两个优点。首先,模型可以更多地关注图像中心区域的信息块,而不是角落区域。其次,对于每个块,绝对位置信息和与其他块的相对位置信息可以被网络很好地建模。

实验

cd2ab089c4200a3d861a86bb9f1fe093.jpeg6b2411ba7c567f179fe058a979afb383.jpeg

声明

本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。

2a72c2a03ff4b7faf098aced652457a0.png

<think>好的,我需要回答用户关于UNet3D模型的使用教程、架构详解和代码实现的问题。首先,我需要回顾一下UNet的基本结构,因为UNet3D是其三维扩展。根据引用[3],U-Net有Encoder-Decoder结构,包含下采样和上采样路径,用于分割UNet3D应该是在此基础上将二维操作扩展到三维,比如用3D卷积代替2D卷积。 接下来,架构详解部分需要详细说明UNet3D的各组成部分。比如编码器部分可能由多个3D卷积层和下采样层(如3D MaxPooling)组成,每个阶段提取更高层次的特征。跳跃连接(skip connections)在U-Net中很重要,用于将编码器的特征与解码器的上采样结果拼接,帮助恢复空间信息。解码器部分可能使用转置卷积或上采样加卷积来进行上采样。 然后,关于代码实现,用户可能需要一个示例。我需要考虑使用深度学习框架如PyTorch或TensorFlow。假设使用PyTorch,可以定义3D卷积层、下采样和上采样模块。编码器块可能包含两个3D卷积层,接着是激活函数和池化。解码器块则包括上采样和特征拼接,然后是卷积层。最后输出层使用1x1x1卷积和sigmoid激活进行分割。 使用教程方面,需要涵盖数据准备(如三维医学图像的预处理)、模型初始化、训练流程(损失函数如Dice Loss,优化器如Adam)、评估指标(如Dice系数)。数据预处理可能需要归一化或标准化,数据增强如旋转、翻转等。训练循环部分要说明如何迭代数据、计算损失、反向传播等。 同时,用户提供的引用中提到WET-UNetLKM-UNet,特别是状态空间建模和Mamba Block,这可能与UNet的变体有关,但用户的问题集中在UNet3D,所以可能不需要深入这些变体,但如果有相关优化方法可以提及。 需要确保回答结构清晰,分点说明教程、架构和代码。最后生成相关问题,比如如何扩展到3D、优化方法、应用场景等。同时检查引用,比如数据预处理部分引用[1],架构部分引用[3],确保正确标注来源。</think>### UNet3D 模型解析与使用指南 #### 一、UNet3D 架构详解 UNet3D 是经典 U-Net 架构的三维扩展,核心结构由**编码器(收缩路径)**、**解码器(扩展路径)**和**跳跃连接**组成,专为处理三维体数据(如学影像 CT、MRI)设计[^3]。 1. **编码器(下采样路径)** - 使用多个 **3D 卷积层**(如 `Conv3d`)提取空间特征,每级后接 **3D 最大池化层**(如 `MaxPool3d`)降低分辨率。 - 示例模块: $$ \text{Input} \rightarrow \text{Conv3d}(k=3) \rightarrow \text{ReLU} \rightarrow \text{Conv3d}(k=3) \rightarrow \text{ReLU} \rightarrow \text{MaxPool3d}(s=2) $$ 2. **解码器(上采样路径)** - 通过**转置卷积**或**插值上采样**(如 `ConvTranspose3d`)恢复空间分辨率。 - 跳跃连接将编码器特征与解码器特征**拼接(Concatenate)**,增强细节保留能力。 - 示例模块: $$ \text{UpSampled} \rightarrow \text{Concatenate(skip)} \rightarrow \text{Conv3d}(k=3) \rightarrow \text{ReLU} \rightarrow \text{Conv3d}(k=3) \rightarrow \text{ReLU} $$ 3. **跳跃连接** - 直接传递编码器的高分辨率局部特征到解码器,解决三维数据中长程依赖性问题[^2]。 #### 二、代码实现(PyTorch 示例) ```python import torch import torch.nn as nn class DoubleConv(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv = nn.Sequential( nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1), nn.BatchNorm3d(out_ch), nn.ReLU(inplace=True), nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1), nn.BatchNorm3d(out_ch), nn.ReLU(inplace=True) ) def forward(self, x): return self.conv(x) class UNet3D(nn.Module): def __init__(self, in_ch=1, out_ch=2): super().__init__() # 编码器 self.enc1 = DoubleConv(in_ch, 64) self.pool1 = nn.MaxPool3d(2) self.enc2 = DoubleConv(64, 128) self.pool2 = nn.MaxPool3d(2) # 解码器 self.up1 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2) self.dec1 = DoubleConv(128, 64) # 输入通道数为跳跃连接拼接后的值 # 输出层 self.out_conv = nn.Conv3d(64, out_ch, kernel_size=1) def forward(self, x): # 编码 e1 = self.enc1(x) e2 = self.enc2(self.pool1(e1)) # 解码 d1 = self.up1(e2) d1 = torch.cat([e1, d1], dim=1) # 跳跃连接 d1 = self.dec1(d1) return self.out_conv(d1) ``` #### 三、使用教程 1. **数据准备** - 输入数据格式:`(Batch, Channel, Depth, Height, Width)` - 预处理:标准化(如 `(x - mean)/std`)或归一化到 $[0,1]$[^1]。 2. **模型训练** ```python model = UNet3D(in_ch=1, out_ch=2) criterion = nn.CrossEntropyLoss() # 或 DiceLoss optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) for epoch in range(100): for x, y in dataloader: pred = model(x) loss = criterion(pred, y) optimizer.zero_grad() loss.backward() optimizer.step() ``` 3. **应用场景** - 学影像分割(肿瘤、器官) - 工业三维缺陷检测 - 遥感体数据解析
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值