https://github.com/JCruan519/VM-UNet
VM-UNet
摘要
CNN在长距离建模方面存在局限性。基于CNN的模型受限于其局部感受野,这大大阻碍了它们捕获长距离信息的能力。这通常导致提取的特征不足,从而导致分割结果不理想。
Transformer由于其二项式计算复杂度而受到限制。尽管基于Transformer的模型在全局建模方面表现出色,但自注意力机制对图像大小的计算复杂度要求是二次的,导致计算负担高,
状态空间模型(SSMs),如Mamba,已经成为一种有前景的方法。它们不仅擅长建模长距离交互,而且保持了线性计算复杂度。
引言
UNet+Transformer
TransUnet是Transformer模型的先驱之一,首次在编码阶段使用视觉Transformer(ViT)进行特征提取,并在解码阶段使用CNN,展示了其获取全局信息的重要能力。
SwinUNet编码解码阶段都使用了SwinTransformer
状态空间模型SSMs
现代SSM(eg:Mamba)不仅建立了长距离依赖性,而且相对于输入大小表现出线性复杂度。
U-Mamba最近引入了一种新的SSM-CNN混合模型,标志着其在医学图像分割任务中的首次应用。
SegMamba整合了SSM到编码器部分,同时仍然在解码器部分使用CNN。
本文首次介绍了Vision Mamba UNet(VM-UNet),这是一个为展示在医学图像分割任务中潜力而设计的纯SSM-based模型。
VM-UNet由编码器、解码器和跳跃连接三部分组成。编码器由VMamba的特征提取VSS块组成,以及用于下采样的patch合并操作。相反,解码器由VSS块和用于恢复分割结果大小的patch扩展操作组成。对于跳跃连接部分,为了突出最原始纯SSM-based模型的分割性能,我们采用了最简单的加法操作形式。
预备知识
在现代基于状态空间模型(SSM)的模型中,例如Structured State Space Sequence Models(S4)和Mamba,都依赖于将一维输入函数或序列通过中间隐状态映射到输出的经典的连续系统。该过程可以表示为线性常微分方程(ODE):
方法
Vision Mamba UNet (VM-UNet)
如图1(a)所示,VM-UNet的整体架构包括一个Patch Embedding层、一个编码器、一个解码器、一个Final Projection层和跳跃连接。与以往的方法不同,我们没有采用对称结构,而是采用不对称设计。Patch Embedding层将输入图像 x∈H×W×3 分割成不重叠的4×4大小的patches,然后将图像的维度映射到 C,其中 C 默认为96。这个过程导致嵌入后的图像 x′∈H/4×W/4×C。最后,使用层归一化(Layer Normalization)对 x′ 进行归一化,然后输入到编码器进行特征提取。
编码器由四个阶段组成,在前三个阶段的末尾应用patch Merging操作,以减少输入特征的H高度和W宽度,同时增加C通道数。我们在四个阶段中分别使用了2、2、2、2个VSS块,每个阶段的通道数分别为C,2C,4C,8C。
解码器也分为四个阶段。在最后三个阶段的开始,使用patch Expanding操作来减少特征C通道数并增加H高度和W宽度。在四个阶段中,我们使用了2、2、2、1个VSS块,每个阶段的通道数分别为8C,4C,2C,C。解码器之后是一个Final Projection层,用于将特征的大小恢复到与分割目标匹配。具体来说,通过patch Expanding进行4倍上采样,以恢复特征的高度和宽度,然后通过一个投影层来恢复通道数。对于跳跃连接,采用了简单的加法操作,没有额外的复杂性,因此没有引入任何额外的参数。
VSS block
VSS块源自VMamaba,是VM-UNet的核心模块,如图1(b)所示。经过Layer Norm层归一化后,输入被分成两个分支。
第一个分支,输入--->线性层--->激活函数。
第二个分支,输入--->线性层--->深度可分离卷积--->激活函数--->2D-Selective-Scan (SS2D)模块进行进一步的特征提取。--->Layer Norm对特征进行归一化,×第一个分支的输出进行逐元素乘法,以合并两个路径。最后,通过一个线性层混合特征,并通过与残差连接的组合形成VSS块的输出。在本文中,默认使用SiLU作为激活函数。
SS2D由三个部分组成:扫描扩展操作、S6块和扫描合并操作。如图2(a)所示
扫描扩展操作将输入图像沿四个不同方向(从左上到右下、从右下到左上、从右上到左下、从左下到右上)展开成序列。然后,由S6块处理这些序列以提取特征,确保从不同方向扫描的信息被彻底扫描,从而捕获不同的特征。随后,如图2(b)所示,扫描合并操作将来自四个方向的序列相加并合并,将输出图像恢复到与输入相同的大小。
S6块源自Mamba,通过根据输入调整SSM的参数,在S4的基础上引入了选择性机制。这使得模型能够区分并保留相关信息,同时过滤掉不相关的信息。S6块的伪代码在算法1中给出。
结果:
---------------------------------------------------------------------------------------------------------------------------------
VM-UNet-V2
与VM-UNet的主要区别,就是跳跃连接变成了SDI,用的CBAM。
SDI
语义和细节注入(SDI)模块是 VM-UNetV2 模型中用于增强特征融合的关键组件。以下是对SDI模块更详细的解释:
功能和目的
SDI模块的目的是进一步提升模型在医学图像分割中的表现,通过更有效地融合不同层次的特征(即,低级和高级特征)。这种融合有助于模型同时捕获图像的局部细节和全局上下文信息。
工作流程
-
特征生成 f0(源自编码器):
- 在编码器部分处理完输入图像后,会产生一系列层次化的特征图(例如,f0i)。
-
注意力机制 f1(源自f0+SDI注意力):
- SDI模块采用注意力机制(如CBAM,即卷积块注意力模块)来计算空间和通道的注意力分数。这有助于模型识别特征图中最重要的部分,从而在融合特征时给予这些部分更多的权重。
-
特征调整 f2(源自f1+1*1卷积调整通道数):
- 通过注意力机制处理后的特征图 f1i会经过1×1的卷积操作来调整通道数,以匹配目标输出的通道数(记为 f2i)。
-
多尺度特征融合 f3(源自编码器):
- SDI模块会处理来自不同编码器阶段的特征图,调整它们的大小以匹配目标特征图 f2i的尺寸。这一步骤涉及使用不同的操作,如自适应平均池化、恒等映射和双线性插值,来确保特征图在空间尺寸上的一致性。
-
特征融合 f4=f2 + f3:
- 调整后的特征图 f3i会与目标特征图 f2i进行融合。这种融合可以通过加权求和、拼接或其他合并操作完成,目的是将不同尺度的特征信息有效地结合起来。
-
输出特征 f5=f4*H:
- 融合后的特征图 f5i会被送入解码器部分,用于进一步的分辨率重建和最终的分割图生成。
技术细节
- 空间和通道注意力:通过CBAM实现,模型能够关注特征图中最相关的区域和通道。
- 特征融合策略:SDI模块采用复杂的融合策略,确保不同尺度和抽象级别的特征得到有效整合。
- 多尺度处理:通过调整不同阶段输出的特征图大小,SDI模块能够处理多尺度的特征,这对于医学图像分割尤其重要。
优点
- 增强的上下文感知:通过融合不同层次的特征,模型能够更好地理解图像的全局上下文。
- 改善的细节捕获:注意力机制帮助模型集中于图像中的关键细节,从而提高分割精度。
- 灵活性和适应性:SDI模块的设计使其能够适应不同的数据集和分割任务。