一、算法结构详解:nnUNet-Mamba架构
我们的算法将传统医学分割模型 nnUNet
与新型序列建模模块 Mamba
结合,形成独特的"编码-解码+状态空间建模"混合架构。
核心设计目标:在保持医学影像分割精度的同时,利用状态空间模型的线性复杂度优势与硬件感知优化,实现在低算力平台(如移动端、边缘设备)的高效部署。
关键创新点:在每级下采样后插入 Mamba模块(黄色块),与 卷积层 交替作用,这将替代传统 Transformer
中的 自注意力机制
核心优化逻辑总结
- 输入阶段:通过 中心裁剪 + 动态注意力,聚焦
ROI
区域,减少无效背景计算 - 编码阶段:
Mamba
模块 替代自注意力,线性复杂度从根本上降低算力需求 - 模型设计:残差连接 + 轻量化卷积,在精度损失 <
1%
的前提下,参数减少30%
- 硬件适配:
Parallel Scan
+ 定点运算优化,实现 低配置设备 的实时推理
1. 输入预处理阶段
• 输入图像尺寸:572
× 572
高分辨率医学图像(如CT/MRI
)
• 特征初提取:经3×3卷积(输出尺寸570
× 570
,通道数默认 64
)+ ReLU
激活 (较传统 UNet
减少 50%
初始通道)
• 多尺度特征保留:执行 「复制与裁剪」(Copy and crop
)操作,将 570
× 570
特征图复制为双路,进行 双路特征分流
• 主路径:进行深度特征提取,继续卷积压将图像尺寸缩至 568
× 568
• 跳跃连接:裁剪为392
× 392
保留 中尺度特征。跳跃连接的特征图通过 通道注意力门控 动态选择有用特征:
α = σ ( W ⋅ [ F low , F high ] ) F fused = α ⊙ F low + ( 1 − α ) ⊙ F high \begin{aligned} \alpha &= \sigma\left( W \cdot [ F_{\text{low}}, F_{\text{high}} ] \right) \\ F_{\text{fused}} &= \alpha \odot F_{\text{low}} + (1 - \alpha) \odot F_{\text{high}} \end{aligned} αFfused=σ(W⋅[Flow,Fhigh])=α⊙Flow+(1−α)⊙Fhigh
其中 F low 为深层特征, F high 为浅层特征, σ 为 sigmoid 函数 \begin{aligned} & \text{其中 } F_{\text{low}} \text{ 为深层特征,} F_{\text{high}} \text{ 为浅层特征,} \sigma \text{ 为 sigmoid 函数} & \end{aligned} 其中 Flow 为深层特征,Fhigh 为浅层特征,σ 为 sigmoid 函数
我们还采用了 动态裁剪策略,使用 渐进式的中心裁剪,以图像中心为锚点,按比例缩小保留区域(如从570
×570
裁剪到 392
×392
时,保留中心 68.8%
区域)。这种仅对中尺度特征(392×392)进行 通道加权 的方法,较全尺寸注意力(Swim-UNet
)减少 70% 计算量
2. 多级下采样编码( nnUNet
部分)
我们对 nnUNet
网络进行了的轻量化改进:通过 Mamba
模块替代高耗 Transformer
,构建 「卷积 - 池化 - Mamba 高效组合,算力需求降低 60%。同时我们保留了经典 UNet
式 收缩路径,通过 卷积池化堆叠 逐步提取深层特征。
下采样阶段 | 操作序列 | 输出尺寸 | Mamba参数配置 |
---|---|---|---|
Stage 1 | Conv3×3 → ReLU → MaxPool2×2 | 284 ×284 | Hidden state dimension=256 |
Stage 2 | Conv3×3 → Conv3×3 → MaxPool2×2 | 104 ×104 | Hidden state dimension=512 |
Stage 3 | Conv3×3 → Mamba Block → MaxPool2×2 | 32 ×32 | Hidden state dimension=1024 |
计算效率优化点:
- 首次池化后特征减半,显存占用 ↓
50%
- 双卷积层参数减少
20%
,引入残差连接 - Mamba 块替代自注意力,计算量从 O (N²) 降至 O (N)
Mamba
架构插入策略(算力优化核心):
- “三明治结构”: 在每级池化后插入,形成 Conv→Pool→Mamba 的 “三明治结构”。先降采样再建模,特征序列长度
N
减少50%
后进入Mamba
,计算量进一步降低 - 补丁合并优化: 每次池化将特征图划分为
2×2
补丁,通道数翻倍但 分辨率减半,显存占用仅为Transformer
的 35% - 残差连接轻量化: 采用 残差连接 缓解梯度消失:Output = MambaBlock(x) + x
3. Mamba
模块:低算力平台的核心引擎
(详情参考我们的另一篇介绍文档Mamba
的内容,在这里附上另一篇文章的链接: Mamba块讲解)
Mamba
是一种基于 **结构化状态空间模型(SSM
, State Space Model
)**的序列建模架构,具有 三大硬件友好特性
对比传统 Transformer
框架,其核心突破在于以下3个方面:
-
选择性机制:动态调整对输入特征的关注权重(传统
SSM
的固定参数被替代) -
线性复杂度建模:
- Δ = τ ⋅ softplus ( W Δ x t ) 采用动态时间步长,避免冗余计算 \Delta = \tau \cdot \text{softplus}(W_{\Delta}x_t) \quad \text{采用动态时间步长,避免冗余计算} Δ=τ⋅softplus(WΔxt)采用动态时间步长,避免冗余计算
-
显存优化技术: 进行
Parallel Scan
并行扫描,有效地将GPU
显存占用降至 2.1GB(相当于RTX 3090
处理512
×512
图像),较Transformer
的6.8GB
降低 65% 。还通过Cross-Scan Module
直接利用图像本身像素的空间顺序,减少了额外的参数(如ViT
的positional embedding
占10%
参数) -
连续-离散混合建模:将连续时间微分方程与离散序列处理无缝衔接
4. 算法数学模型
连续时间状态空间方程(捕获特征演化连续性):
d
h
(
t
)
d
t
=
A
(
t
)
h
(
t
)
+
B
(
t
)
x
(
t
)
\frac{dh(t)}{dt} = A(t)h(t) + B(t)x(t)
dtdh(t)=A(t)h(t)+B(t)x(t)
y ( t ) = C ( t ) h ( t ) + D ( t ) x ( t ) y(t) = C(t)h(t) + D(t)x(t) y(t)=C(t)h(t)+D(t)x(t)
- A(t): 状态转移矩阵,控制隐状态 h(t) 的自演化规律
- B(t): 输入投影矩阵,决定当前输入 x(t) 对状态的影响强度
- C(t): 输出投影矩阵,提取状态中有用的特征成分
- D(t): 跳跃连接矩阵,保留原始输入的高频细节
状态空间方程离散化过程:
Δ
=
τ
⋅
softplus
(
W
Δ
x
t
)
A
ˉ
=
exp
(
Δ
A
)
≈
I
+
Δ
A
一阶泰勒展开,在算力不足时进行近似计算
B
ˉ
=
(
Δ
A
)
−
1
(
exp
(
Δ
A
)
−
I
)
Δ
B
≈
Δ
B
这是简化版的输出投影
\begin{aligned} \Delta &= \tau \cdot \text{softplus}(W_{\Delta}x_t) \\ \bar{A} &= \exp(\Delta A) \approx I+\Delta A \text{ 一阶泰勒展开,在算力不足时进行近似计算}\\ \bar{B} &= (\Delta A)^{-1}(\exp(\Delta A) - I)\Delta B \approx \Delta B\text{ 这是简化版的输出投影} \end{aligned}
ΔAˉBˉ=τ⋅softplus(WΔxt)=exp(ΔA)≈I+ΔA 一阶泰勒展开,在算力不足时进行近似计算=(ΔA)−1(exp(ΔA)−I)ΔB≈ΔB 这是简化版的输出投影
其中 (τ
) 为时间步长系数,通过学习可以进行动态调整。本模型适合建模医学影像中的 渐进性病理改变(如肿瘤生长、组织纤维化)
维度 | Transformer | Mamba |
---|---|---|
长序列建模能力 | O(N²) 注意力计算 | O(N) 线性复杂度 |
硬件效率 | 高显存占用 | 显存占用降低65% |
动态感受野 | 固定窗口注意力 | 全局自适应感知 |
医学图像适用性 | 适合小尺寸图像 | 支持 1024 ×1024 |
二、算法核心优越性分析(聚焦低算力优势与临床价值)
1. 全局 + 局部特征协同建模能力突破
传统 nnUNet
依赖局部卷积堆叠捕捉全局信息,导致大器官分割时 边界模糊;而 Transformer
类模型虽能建模长程依赖,却因高算力消耗难以获取硬件支持。 nnUNet-Mamba
通过 Mamba
模块与卷积层的协同设计,实现 「局部细节保留 + 全局语义对齐」 的双重优势:
维度 | 传统 nnUNet | TransUNet /Swin-UNet | nnUNet-Mamba |
---|---|---|---|
特征交互范围 | 局部 3 × 3 卷积(感受野有限) | 窗口化自注意力(固定范围) | 全局线性建模 + 动态通道注意力 |
长程依赖效率 | 需 10 + 层堆叠(计算冗余) | O (N²) 二次复杂度,要求高算力 | 单层 Mamba 实现全图建模(O (N)) |
典型场景表现 | 小器官(如胰腺)分割尚可 | 中等尺寸的器官计算耗时 | 大器官 / 多发病灶(如全肝、腹部多器官)分割精度提升 3%-5% |
在 Synapse
腹部 CT 多器官分割数据集 中,图中示例显示nnUNet-Mamba
对肝这类大器官的 Dice
系数分别达 0.965
、0.948
,较传统 nnUNet
框架提升 4.6%
、2.4%
2. 计算效率突破:从「高算力依赖」到「端侧实时推理」
我们的算法通过 线性复杂度建模 + 硬件感知优化,突破传统 Transformer
的算力瓶颈,成为 低算力平台 的理想选择
训练与测试数据: LiTS - Liver Tumor Segmentation Challenge 数据集
该数据集收集了 7 个不同医学中心的数据,包含 131
例训练集 和 70
例测试集,其中 测试数据标签不公开,文件格式为 .nii
我们将 nnUNet-Mamba
模型与 Transformer
类模型 Swin-UNet
做性能对比(512
× 512
图像,RTX 2060
实测数据):
指标 | Transformer 类(Swin-UNet ) | nnUNet-Mamba | 技术突破点 |
---|---|---|---|
计算复杂度 | O(N²) (二次方 复杂度) | O(N) (线性复杂度) | 建模长程依赖的计算复杂度从 二次方降至线性,避免传统 Transforme r 随图像尺寸增长的算力爆炸问题 |
显存占用 | 5.2GB (需 独立显卡 支持) | 2.3GB (4GB 显存设备稳定运行) | 通过 Parallel Scan 并行扫描技术,显存占用降低 56% ,支持中端显卡(如 RTX 2060 )及嵌入式 GPU 部署 |
推理耗时 | 45ms (仅 GPU 端优化) | 25ms (GPU /CPU 混合加速) | 在 RTX 2060 上实测速度提升 44% ,同等算力下推理效率显著优于自注意力机制 |
3. 基于小样本学习:数据效率的本质提升
因为医学影像标注成本高昂,我们的算法设计的初衷就是在如何在有限的数据下,让模型达到更好的性能。nnUNet-Mamba
凭借 状态空间模型的强归纳偏置,在小样本场景展现显著优势
训练与测试数据: 第九届全国大学生生物医学工程创新设计竞赛消化系统诊疗赛道数据集
该数据集包含 289
例训练集 和 50
例测试集,其中 测试数据标签不公开,文件格式为 .nii.gz
在仅有 289
例标注数据的 肝癌病灶分割任务 中(与经典模型 nnUNet
对比):
模型 | Dice | HD95↓(mm) | 关键差异点 |
---|---|---|---|
nnUNet | 0.723 | 4.57 | 依赖 局部卷积堆叠,难以对肿瘤的异质性生长模式进行建模 |
nnUNet-Mamba | 0.812 | 2.89 | Dice 提升 12.3% ,HD95 误差降低 36.8% ,肿瘤边缘分割更精准 |
Mamba
模块通过连续时间建模捕获可以捕获病灶生长规律特征,因此在一定程度上减少对数据量的依赖
4.小尺度病灶分割优势:从「漏诊盲区」到「精准捕捉」
在医学影像中,小尺度病灶(如早期肝癌微灶,<3mm
的肺磨玻璃结节)的分割是临床难点,是早期诊断的关键瓶颈。 传统模型常因 局部感受野不足(如卷积神经网络 CNN
) 或 自注意力背景淹没效应(小病灶像素占比 <0.1%
时被忽略) 导致病灶漏画或推理出病灶的边界模糊。我们的nnUNet-Mamba
通过全局上下文建模与多尺度细节保留,实现小病灶分割性能的提升。
Mamba
模块的 结构化状态空间模型(SSM) 通过 连续时间方程:
d
h
(
t
)
d
t
=
A
(
t
)
h
(
t
)
+
B
(
t
)
x
(
t
)
\frac{dh(t)}{dt} = A(t)h(t) + B(t)x(t)
dtdh(t)=A(t)h(t)+B(t)x(t)
这种将 全图像素序列视为动态演化的状态序列 的方法,即使是 5mm
的微小病灶(约 25
像素),也能通过状态转移矩阵 A (t) 捕捉其与周围组织的依赖关系。通过 **B (t) **输入投影矩阵增强病灶区域特征权重,抑制肝实质的噪声干扰
测试数据集:LIDC-IDRI(肺结节)+ 第九届全国大学生生物医学工程创新设计竞赛消化系统诊疗赛道(微小转移灶)
指标 | 传统 nnUNet | Swin-UNet | nnUNet-Mamba | 优势解析 |
---|---|---|---|---|
小病灶 Dice↑ | 0.682 | 0.715 | 0.794 | 全局上下文引导的边缘定位 |
边缘误差↓(mm) | 1.82 | 1.57 | 1.13 | 动态通道注意力精准对齐病灶与背景边界 |