LightM-UNet:一个基于Mamba的轻量级U形分割模型

UNet作为一种基于卷积神经网络(CNN-based)的模型,面临卷积操作固有局部性,限制了它理解明确的全局和长距离语义信息交互的能力。一些研究尝试通过扩张卷积层,自我注意力机制和图像金字塔来缓解这个问题但在建模长距离依赖方面仍表现出限制。且近期研究探讨了整合Transformer架构,利用自注意力机制将图像视为一系列连续的Patch来捕获全局信息,但由于子注意力机制导致了图像尺寸的二次复杂度,特别对于需要密集预测的任务(如医疗图像分割),带来了相当大计算开销。所以仍然存在一个问题:“如何在不增加额外参数和计算负担的情况下,赋予UNet容纳长距离依赖的能力?”

近期,由于状态空间模型(SSMs)不仅能建立长距离依赖关系,而且还有输入规模的线性复杂性。于是一些尝试如U-Mamba【14】,提出了一个混合的CNN-SSM块,结合了卷积层的局部特征提取能力与SSM在捕捉纵向依赖关系方面的专长但其引入了大量参数和计算负载,带来了相当大计算开销。在这项研究,引入了LightM-UNet,一个基于MAmba的轻量级U形分割模型,显著降低参数和计算成本的同时实现了最先进的性能。(如下图)

LightM-UNet是UNet与Mamba的轻量级融合,超越了现有的最先进模型。在技术方面,提出了残差视觉曼巴层(RVM层)”以纯曼巴方式从图像中提取深层特征。在引入的新参数和计算开销最小的情况下,作者通过使用“残差连接”和“调整因子”进一步增强了SSM对视觉图像中长距离空间依赖关系建模的能力。

方法论

LightM-UNet的总体架构如图2所示。

给定一个输入图像 ,其中C 、H、W和D分别表示3D医疗图像的通道数、高度、宽度和切片数。LightM-UNet首先使用深度可分卷积(DWConv)层进行浅层特征提取,生成浅层特征图其中32表示固定的滤波器数量。随后,LightM-UNet结合三个连续的编码器块(Encoder Blocks)从图像中提取深层特征。在每个编码器块之后,特征图中的通道数翻倍,而分辨率减半。在-th编码器块处,LightM-UNet提取深层特征

LightM-UNet使用瓶颈块(Bottleneck Block)来建模长距离空间依赖关系,同时保持特征图的大小不变。之后,LightM-UNet整合三个连续的解码器块(Decoder Blocks)进行特征解码和图像分辨率恢复。在每个解码器块之后,特征图中的通道数减半,分辨率加倍。

最后,最后一个解码器块的输出达到与原始图像相同的分辨率,包含32个特征通道。LightM-UNet使用DWConv层将通道数映射到分割目标数,并应用SoftMax激活函数生成图像 Mask 。与UNet的设计一致,LightM-UNet也采用跳跃连接(skip connections)为解码器提供多 Level 特征图。

Encoder Block

为了最小化参数数量和计算成本,LightM-UNet采用了仅包含Mamba结构的编码器块来从图像中提取深层特征。

给定一个特征图编码器块首先将特征图展平并转置成的形状,其中

随后,编码器块使用个连续的RVM层来捕捉全局信息,在最后一个RVM层中通道数增加。此后,编码器块重新调整并转置特征图的形状为紧接着进行最大池化操作以降低特征图的分辨率。最终,第个编码器块输出新的特征图,其形状为

Residual Vision Mamba Layer (RVM Layer)

LightM-UNet提出了RVM层以增强原始的SSM块,用于图像深层语义特征提取。具体来说,LightM-UNet利用先进的残差连接和调整因子进一步增强了SSM的长距离空间建模能力,几乎不引入新的参数和计算复杂性。

对于给定的输入深层特征,RVM层首先采用LayerNorm,然后是VSSM来捕捉空间长距离依赖。随后,它在残差连接中使用调整因子以获得更好的性能。这个过程可以用以下数学方式表示:

紧随其后,RVM层使用另一个LayerNorm来规范化,随后利用一个投射层将转换为一个更深的特征。上述过程可以表述为:

视觉状态空间模块(VSS模块)遵循[13]中概述的方法,LightM-UNet引入了VSS模块(如图2(b)所示)进行长距离空间建模。VSS模块以特征作为输入,并将其引导到两个并行分支中。在第一个分支中,VSS模块使用线性层将特征通道扩展到其中表示预定义的通道扩展因子。

随后,它应用了DWConv、SiLU激活函数[20],然后是SSM和层归一化。在第二个分支中,VSS模块同样使用线性层将特征通道扩展到之后是SiLU激活函数。随后,VSS模块通过哈达玛积从两个分支聚合特征,并将通道数投射回,生成与输入形状相同的输出。上述过程可以公式化为:

其中表示哈达玛积。

Bottleneck Block

类似于Transformer,当网络深度变得过大时,Mamba也会遇到收敛挑战。因此,LightM-UNet通过结合四个连续的RVM层来构建瓶颈,以进一步建模空间长期依赖关系,从而解决这个问题。在这些瓶颈区域中,特征通道的数量和分辨率保持不变。

Decoder Block

LightM-UNet使用了解码器块(Decoder Blocks)来解码特征图并恢复图像分辨率。具体来说,给定来自跳跃连接的和来自前一个块的输出的解码器块首先通过加法操作进行特征融合。

随后,它使用一个深度卷积(DWConv)、一个残差连接以及一个ReLU激活函数来解码特征图。另外,一个调整因子被加到残差连接上以增强解码能力。这个过程可以用数学方式表达为:

解码器块最终采用双线性插值方法将预测恢复到原始分辨率。

Experiments

选择了两个公开可用的医学图像数据集:LiTs数据集[1],包含3D CT图像;以及Montgomery&Shenzhen数据集[9],包含2D X光图像。这些数据集本研究中分别用来验证2D和3D版本的LightM-UNet的性能。数据被随机划分为训练集、验证集和测试集,比例分别为7:1:2。

LIghtM-UNet使用PyTorch框架实现,三个编码器块中RVM层数分别设置1、2和2。在单个Quadro RTX 8000 GPU上进行,采用SGD作为优化器,初始学习率1e-4.PolyRScheduler作为调度器,训练了100个周期。

损失函数被设计为交叉熵损失和Dice损失的简单组合。

LiTs数据集:图像被归一化并调整至128*128*128的大小,批量大小为2.

Montgomery&Shenzhen数据集:图像被归一化并调整至512*512的大小,批量大小为12。

为了评估LightM-UNet,作者将其与两种基于CNN的分割网络(nnU-Net和 SegResNet)、两种基于Transformer的网络(UNETR和 SwinUNETR)以及一种基于Mamba的网络(U-Mamba)进行了比较,同时还采用平均交并比(mloU)和Dice相似度得分(DSC)作为评估指标。

由对比实验结果,作者的LightM-UNet就算和nnU-Net这种大型模型相比表现出更优越的性能,也显著减少了参数数量和计算成本,分别降低了47.39*和15.82*。

同期的U-Mamba相比,LightM-UNet在平均mIoU方面提高了2.11%的性能。特别是对于通常太小而难以轻易检测的肿瘤,LightM-UNet在mIoU上实现了3.63%的改进。重要的是,作为一种将Mamba方法融入UNet架构的方法,与U-Mamba相比,LightM-UNet仅使用了少1.07%的参数和2.53%的计算资源。

展示了分割结果示例,表明与其他模型相比,LightM-UNet具有更平滑的分割边缘,并且不会对小型目标(如肿瘤)产生错误的识别。

作者进一步去除了RVM层中的调整因子和残差连接。实验结果表明,在移除这两个组件后,模型的参数数量和计算开销几乎没有减少,但模型的性能显著下降(mIoU下降了0.44%和0.69%)。这验证了作者在不引入额外参数和计算开销的情况下提升模型性能的基本原则。

  • 13
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值