文章目录
1 论文解读
论文标题:
VM-UNet: Vision Mamba UNet for Medical Image Segmentation
论文链接:
https://arxiv.org/abs/2402.02491
源码链接:
https://github.com/JCruan519/VM-UNet
1.1 引言
Abstract—In the realm of medical image segmentation, both CNN-based and Transformer-based models have been extensively explored. However, CNNs exhibit limitations in long-range modeling capabilities, whereas Transformers are hampered by
their quadratic computational complexity. Recently, State Space Models (SSMs), exemplified by Mamba, have emerged as a
promising approach. They not only excel in modeling long-range interactions but also maintain a linear computational complexity. In this paper, leveraging state space models, we propose a Ushape architecture model for medical image segmentation, named Vision Mamba UNet (VM-UNet). Specifically, the Visual State Space (VSS) block is introduced as the foundation block to capture extensive contextual information, and an asymmetrical encoder-decoder structure is constructed with fewer convolution layers to save calculation cost. We conduct comprehensive experiments on the ISIC17, ISIC18, and Synapse datasets, and the results indicate that VM-UNet performs competitively in medical image segmentation tasks. To our best knowledge, this is the first medical image segmentation model constructed based on the pure SSM-based model. We aim to establish a baseline and provide valuable insights for the future development of more efficient and effective SSM-based segmentation systems.
医学图像分割是计算机视觉领域的一个重要研究方向,旨在通过自动化技术帮助医生更快地进行病理诊断,从而提高患者护理的效率。近年来,基于卷积神经网络(CNN)和 Transformer 的模型在医学图像分割任务中表现出色。然而,CNN 模型在长距离建模能力上存在局限,而 Transformer 模型则因其二次计算复杂度而受到限制。最近,状态空间模型(SSMs),尤其是 Mamba
模型,因其在长距离建模和线性计算复杂度方面的优势,成为了一个备受关注的研究方向。
本文提出了一种基于状态空间模型的U形架构模型,称为 Vision Mamba UNet(VM-UNet)
,专门用于医学图像分割。VM-UNet
通过将 Mamba
的线性复杂度建模能力与 U-Net
的经典架构相结合,通过引入视觉状态空间(VSS)块来捕捉广泛的上下文信息,并采用非对称的编码器-解码器结构以减少计算成本。实验结果表明,VM-UNet
在多个医学图像分割数据集上表现出色,成为首个基于纯 SSM 模型的医学图像分割模型,在医学图像分割任务中实现了高效的长距离依赖捕获。
代码中创新的双注意力机制(SC_Att_Bridge
)和分块并行 Mamba 处理(PVMLayer
)是其性能优越的关键。
(Mamba
的内容在博主的另一篇文章【手撕代码】Mamba源码讲解)
1.2 VM-UNet的核心架构
VM-UNet
的整体架构如图所示,主要由以下几个部分组成:
-
Patch Embedding层:将输入图像划分为不重叠的 4x4 patches,并将其映射到指定的通道数(默认96)。
-
编码器(Encoder):由四个阶段组成,每个阶段包含多个
VSS
块,并在前三个阶段末尾进行patch merging
操作以降低特征图的分辨率并增加通道数。 -
解码器(Decoder):同样由四个阶段组成,每个阶段包含多个
VSS
块,并在后三个阶段开始处进行patch expanding
操作以恢复特征图的分辨率。 -
Final Projection层:通过
patch expanding
操作将特征图的分辨率恢复到原始大小,并通过投影层恢复通道数。 -
Skip Connections:采用简单的加法操作,不引入额外的参数。
1.3 VSS块:VM-UNet的核心模块
VSS 块
是 VM-UNet
的核心模块,其结构如图所示。VSS块
的主要操作包括:
-
Layer Normalization:对输入进行归一化处理。
-
双分支结构:
-
第一分支:通过线性层和激活函数(默认SiLU)处理输入。
-
第二分支:通过线性层、深度可分离卷积和激活函数处理输入,然后进入 2D-Selective-Scan(SS2D)模块进行特征提取。
-
-
SS2D模块:由扫描扩展操作、S6块和扫描合并操作组成。扫描扩展操作将输入图像沿四个不同方向展开为序列,S6块对这些序列进行特征提取,扫描合并操作将四个方向的序列合并为输出图像。
-
特征融合:将两个分支的输出进行逐元素相乘,并通过线性层混合特征,最后与残差连接结合形成 VSS 块的输出。
1.4 损失函数
VM-UNet 采用最基础的损失函数来验证纯 SSM 模型在医学图像分割任务中的应用潜力。对于二分类任务,使用 Binary Cross-Entropy
和 Dice
损失(BceDice损失);对于多分类任务,使用 Cross-Entropy
和 Dice
损失(CeDice损失)。
L B c e D i c e = λ 1 ⋅ L B c e + λ 2 ⋅ L D i c e L C e D i c e = λ 3 ⋅ L C e + λ 4 ⋅ L D i c e LBceDice = \lambda_1 \cdot LBce + \lambda_2 \cdot LDice \quad LCeDice = \lambda_3 \cdot LCe + \lambda_4 \cdot LDice LBceDice=λ1⋅LBce+λ2⋅LDiceLCeDice=λ3⋅LCe+λ4⋅LDice
1.5 实验与结果
VM-UNet 在多个医学图像分割数据集上进行了实验,包括 ISIC17、ISIC18 和 Synapse 数据集。实验结果表明,VM-UNet 在皮肤病变分割和多器官分割任务中表现出色,尤其是在 ISIC17 和 ISIC18 数据集上,VM-UNet 在多个评价指标上均优于现有的最先进模型。
1.6 消融实验
为了进一步验证 VM-UNet 的设计选择,作者进行了多项消融实验,包括:
-
初始化权重:使用
VMamba-S
预训练权重初始化的 VM-UNet 在 ISIC17 和 ISIC18 数据集上的表现显著优于未使用预训练权重的模型。 -
Dropout值:不同 Dropout 值对模型性能的影响在不同数据集上表现不同,ISIC17 数据集上无 Dropout 时表现最佳,而 ISIC18 数据集上 Dropout 值为 0.2 时表现最佳。
-
编码器-解码器架构:非对称结构的 VM-UNet 在参数数量和计算量上优于对称结构,且性能更佳。
-
输入分辨率:增加输入分辨率并未显著提升 VM-UNet 的性能,说明在视觉领域中,Mamba 模型对输入序列长度的敏感性需要进一步研究。
1.7 可视化结果展示
通过可视化 VM-UNet
在 ISIC18数据集
上的分割结果,可以看出 VM-UNet 在各种复杂场景下表现出色,尤其是在处理小目标和复杂边界时,VM-UNet 的分割效果显著优于其他模型。
1.8 网络不足之处
尽管 VM-UNet 在医学图像分割任务中表现出色,但仍存在一些局限性。例如,模型在处理超出训练序列长度的数据时可能存在泛化能力不足的问题,且 VM-UNet 的参数量较大。此外,VM-UNet在皮肤病变分割任务中对浅色区域和毛发干扰的敏感性也需要进一步研究。
2. VM-UNet 代码解析
VM-UNet
是基于 状态空间模型(SSM) 的医学图像分割模型,其核心创新在于结合 Mamba
模块的长序列建模能力和 U-Net 的经典编码器-解码器架构。我挑出源码中的主要部分,如不同的注意力机制来讲解,Mamba
部分可以参考博主的另一篇文章【手撕代码】Mamba源码讲解
需要全部代码的的小伙伴可私信作者,或点击此处获取:VM-UNet
2.1 __init__
函数:模型初始化
def __init__(self, num_classes=1, input_channels=3, c_list=[8,16,24,32,48,64], split_att='fc', bridge=True):
super().__init__()
self.bridge = bridge
# 编码器部分(6层)
self.encoder1 = nn.Sequential(nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1))
self.encoder2 = nn.Sequential(PVMLayer(input_dim=c_list[0], output_dim=c_list[1]))
self.encoder3 = nn.Sequential(PVMLayer(input_dim=c_list[1], output_dim=c_list[2]))
self.encoder4 = nn.Sequential(PVMLayer(input_dim=c_list[2], output_dim=c_list[3]))
self.encoder5 = nn.Sequential(PVMLayer(input_dim=c_list[3], output_dim=c_list[4]))
self.encoder6 = nn.Sequential(PVMLayer(input_dim=c_list[4], output_dim=c_list[5]))
# 解码器部分(5层)
self.decoder1 = nn.Sequential(PVMLayer(input_dim=c_list[5], output_dim=c_list[4]))
self.decoder2 = nn.Sequential(PVMLayer(input_dim=c_list[4], output_dim=c_list[3]))
self.decoder3 = nn.Sequential(PVMLayer(input_dim=c_list[3], output_dim=c_list[2]))
self.decoder4 = nn.Sequential(PVMLayer(input_dim=c_list[2], output_dim=c_list[1]))
self.decoder5 = nn.Sequential(PVMLayer(input_dim=c_list[1], output_dim=c_list[0]))
# 注意力桥(可选) 下文重点解释注意力代码
if bridge:
self.scab = SC_Att_Bridge(c_list, split_att)
print('SC_Att_Bridge was used')
# 分组归一化层
self.ebn1 = nn.GroupNorm(4, c_list[0]) # 输入通道分组归一化(每组4通道)
self.ebn2 = nn.GroupNorm(4, c_list[1])
# 最终输出层
self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1)
self.apply(self._init_weights) # 权重初始化
编码器-解码器结构:编码器包含 6 个 PVMLayer
(基于 Mamba
的模块),解码器包含 5 个 PVMLayer
,通过下采样和上采样实现特征提取与重建。
注意力桥:SC_Att_Bridge
结合通道和空间注意力,增强特征表达能力。
分组归一化:使用 GroupNorm
替代 BatchNorm
,避免小批量数据下的统计偏差。
2.2 _init_weights
函数:权重初始化
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02) # 截断正态分布初始化,防止梯度爆炸
if m.bias is not None:
nn.init.constant_(m.bias, 0) # 偏置初始化为0
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) # He初始化
if m.bias is not None:
m.bias.data.zero_()
这段代码核心作用:
- 对全连接层使用截断正态分布初始化,避免梯度爆炸。
- 对卷积层使用 He 初始化,加速模型收敛。
2.3 Channel_Att_Bridge
类:通道注意力机制
Channel_Att_Bridge
模块通过对每个通道的重要性进行建模,来增强网络对重要特征的关注。该模块的工作流程包括以下步骤:
- 全局特征压缩:使用全局平均池化提取每个通道的全局信息。
- 注意力权重生成:使用全连接或卷积层生成每个通道的注意力权重。
- 特征增强:通过
Sigmoid
激活函数归一化注意力权重,并与原始特征图进行逐元素相乘,突出重要的特征。
这种通道注意力
机制有助于模型学习到哪些通道对当前任务最为关键,从而优化性能并减少不相关信息的影响。
class Channel_Att_Bridge(nn.Module):
def __init__(self, c_list, split_att='fc'):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1) # 全局平均池化
self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
# 定义注意力权重生成层(全连接或卷积)
self.att1 = nn.Linear(sum(c_list)-c_list[-1], c_list[0]) if split_att == 'fc' else nn.Conv1d(...)
self.sigmoid = nn.Sigmoid()
def forward(self, t1, t2, t3, t4, t5):
# 拼接各层特征并生成注意力权重
att = torch.cat([self.avgpool(t) for t in [t1,t2,t3,t4,t5]], dim=1)
att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
# 扩展注意力权重至特征图尺寸
att1 = self.sigmoid(self.att1(att)).expand_as(t1)
# ...(其他注意力权重生成,可以查看源码)
return att1, att2, att3, att4, att5
-
全局特征压缩:
self.avgpool = nn.AdaptiveAvgPool2d(1) # 全局平均池化
-
通过
AdaptiveAvgPool2d(1)
,将每个输入的特征图(t1
,t2
,t3
,t4
,t5
)经过全局平均池化处理,压缩成一个固定大小的输出。这是通过计算每个通道的均值来压缩图像信息的步骤。 -
avgpool
是为了从每个特征图的通道中提取全局信息,也就是计算每个通道上的全局平均值。这将产生一个包含每个通道重要性的标量值的张量。
-
-
注意力权重生成:
self.att1 = nn.Linear(sum(c_list)-c_list[-1], c_list[0]) if split_att == 'fc' else nn.Conv1d(...)
-
在这个部分
att1
被定义为一个全连接层或卷积层,具体取决于split_att
参数的值。- 如果
split_att == 'fc'
,则使用全连接层(nn.Linear
)来计算每个通道的注意力权重。 - 如果
split_att
为其他值(例如默认值),则使用卷积层(nn.Conv1d
),用于生成通道级别的权重。
- 如果
-
sum(c_list)-c_list[-1]
:这是将输入通道数的总和(c_list
中所有通道数的和)减去最后一个通道数,以得到输入给注意力层的总通道数。 -
c_list[0]
:这是全连接层或卷积层的输出通道数,表示生成的注意力特征的维度。
-
-
特征增强:
att1 = self.sigmoid(self.att1(att)).expand_as(t1)
- 经过全连接层或卷积层后,通过
sigmoid
激活函数对输出的注意力权重进行归一化处理。 Sigmoid
激活函数:将输出限制在 0 到 1 之间,确保每个通道的注意力权重都是一个合理的比例。通过这种方式,可以确定哪些通道对模型来说更重要,哪些通道应该被抑制。expand_as(t1)
:将计算得到的注意力权重扩展成与输入特征图t1
相同的尺寸,使得每个通道的权重可以与原始特征图逐元素相乘。
- 经过全连接层或卷积层后,通过
-
最终注意力权重生成:
return att1, att2, att3, att4, att5
-
模块的输出是经过注意力加权后的五个输入特征图
att1
,att2
,att3
,att4
,att5
。每个特征图都被乘以对应的注意力权重(att1
,att2
, 等)。 -
这些加权后的特征图将被返回并传递到后续的网络层,突出显示重要的特征。
-
2.4 Spatial_Att_Bridge
类:空间注意力机制
空间注意力机制
的目的是通过在空间维度上加权每个像素的贡献,以便网络可以专注于图像中更重要的区域,而不是对所有区域给予相同的关注,以突出目标区域,抑制背景噪声。
Spatial_Att_Bridge
类通过空间注意力机制,结合了通道特征的全局平均信息和最大值信息,通过空洞卷积
扩展了感受野,在空间上为每个像素生成了加权系数,从而提升了模型在空间维度上的感知能力。其主要步骤如下:
- 计算通道级特征:通过计算每个像素的通道平均值和最大值,捕捉了通道上的重要信息。
- 拼接信息:将通道平均和最大值的信息拼接在一起,增强特征的表达能力。
- 扩展感受野:使用空洞卷积来扩大感受野,让模型能够捕捉更大范围的上下文信息。
- 生成空间注意力:通过卷积和Sigmoid激活生成空间注意力图,对特征图进行加权,突出图像中的重要区域。
class Spatial_Att_Bridge(nn.Module):
def __init__(self):
super().__init__()
self.shared_conv2d = nn.Sequential(
nn.Conv2d(2, 1, kernel_size=7, padding=9, dilation=3), # 空洞卷积扩大感受野
nn.Sigmoid()
)
def forward(self, t1, t2, t3, t4, t5):
att_list = []
for t in [t1,t2,t3,t4,t5]:
avg_out = torch.mean(t, dim=1, keepdim=True) # 通道平均
max_out, _ = torch.max(t, dim=1, keepdim=True) # 通道最大值
att = torch.cat([avg_out, max_out], dim=1) # 拼接后生成空间注意力
att = self.shared_conv2d(att)
att_list.append(att)
return att_list
-
多尺度感受野:空洞卷积
self.shared_conv2d = nn.Sequential( nn.Conv2d(2, 1, kernel_size=7, padding=9, dilation=3), # 空洞卷积扩大感受野 nn.Sigmoid() )
-
空洞卷积:
dilation=3
的卷积层即空洞卷积(Dilated Convolution),它的作用是扩大卷积核的感受野,而不增加计算量。具体而言,空洞卷积通过在卷积核的元素之间引入空洞(即增加元素之间的间隔),从而扩展感受区域。这种做法可以让模型获得更多的上下文信息(即更大范围的图像信息),但计算量并不会成倍增加。 -
感受野扩展:通过空洞卷积,模型可以从更广阔的区域获取信息,而不会对计算性能造成大的影响。在图像分割任务中,感受野的扩大使得模型能考虑更远处的信息,这对处理复杂图像、识别更大范围的目标至关重要。
-
-
双路特征融合:均值与最大值
avg_out = torch.mean(t, dim=1, keepdim=True) # 通道平均 max_out, _ = torch.max(t, dim=1, keepdim=True) # 通道最大值 att = torch.cat([avg_out, max_out], dim=1) # 拼接后生成空间注意力
-
通道平均与最大值:
- 通道平均(
torch.mean(t, dim=1, keepdim=True)
):计算每个像素在各个通道上的平均值,即在空间维度(高度和宽度)上对所有通道进行平均。这个操作可以提取通道级的全局信息。 - 通道最大值(
torch.max(t, dim=1, keepdim=True)
):计算每个像素在各个通道上的最大值,这通常有助于捕获图像中的最显著特征区域。
- 通道平均(
-
拼接操作(
torch.cat([avg_out, max_out], dim=1)
):将通道平均和最大值的结果在通道维度(dim=1
)上拼接,生成一个包含更多信息的特征图。这一操作可以让模型同时关注到平均特征和最大特征,从而提升对空间信息的敏感度。
-
-
生成空间注意力权重
att = self.shared_conv2d(att)
-
空间注意力生成:将拼接后的特征图通过空洞卷积层(
self.shared_conv2d
)进一步处理。通过卷积操作,模型能够自动学习如何将空间上下文信息映射为注意力权重。 -
Sigmoid激活函数:通过
Sigmoid
激活函数,将卷积输出的权重限制在 [0, 1] 之间,确保输出的注意力图像可以与原始输入进行有效的加权融合。
-
-
返回最终空间注意力权重
att_list = att_list.append(att)
- 生成空间注意力图:经过卷积层和 Sigmoid 激活函数后的结果(即每个输入特征图的空间注意力权重)会被加入到一个列表
att_list
中。最终,att_list
中包含了对所有输入特征图(t1
到t5
)的空间注意力加权结果。
- 生成空间注意力图:经过卷积层和 Sigmoid 激活函数后的结果(即每个输入特征图的空间注意力权重)会被加入到一个列表
2.5 SC_Att_Bridge
类:双注意力融合
SC_Att_Bridge
是一个将空间注意力和通道注意力结合起来的模块,能够更有效地关注图像中的关键区域,并且根据通道的重要性进一步优化特征表达。这种方法有助于提升图像分割任务的性能,特别是在复杂的背景和目标区域之间的差异性较大时。
空间注意力优先:首先通过 Spatial_Att_Bridge
提取空间注意力,调整特征图中的空间信息,突出重要区域并抑制不相关区域。
通道注意力细化:接着,通过 Channel_Att_Bridge
提取通道注意力,根据每个通道的重要性进一步优化特征表达,确保模型更加关注关键的通道信息。
class SC_Att_Bridge(nn.Module):
def __init__(self, c_list, split_att='fc'):
super().__init__()
self.catt = Channel_Att_Bridge(c_list, split_att) # 通道注意力
self.satt = Spatial_Att_Bridge() # 空间注意力
def forward(self, t1, t2, t3, t4, t5):
# 空间注意力增强
satt_weights = self.satt(t1, t2, t3, t4, t5)
t1, t2, t3, t4, t5 = [satt * t for satt, t in zip(satt_weights, [t1,t2,t3,t4,t5])]
# 通道注意力增强
catt_weights = self.catt(t1, t2, t3, t4, t5)
t1, t2, t3, t4, t5 = [catt * t for catt, t in zip(catt_weights, [t1,t2,t3,t4,t5])]
return t1, t2, t3, t4, t5
-
空间注意力增强:调整特征图中的空间信息
satt_weights = self.satt(t1, t2, t3, t4, t5) t1, t2, t3, t4, t5 = [satt * t for satt, t in zip(satt_weights, [t1,t2,t3,t4,t5])]
- 空间注意力优先:在
SC_Att_Bridge
中,空间注意力优先被应用到输入特征图t1
,t2
,t3
,t4
,t5
上。首先,模型通过Spatial_Att_Bridge
提取出每个特征图的空间注意力权重(即satt_weights
)。 - 加权融合:空间注意力权重会与原始特征图逐元素相乘(
satt * t
)。这个步骤会根据每个区域的空间重要性调整特征图,突出关键信息并抑制不相关的背景区域。通过空间注意力的作用,模型能够更好地聚焦在目标区域,忽略无关的区域,从而提高分割精度。
- 空间注意力优先:在
-
通道注意力细化:根据通道重要性进一步优化特征表达
catt_weights = self.catt(t1, t2, t3, t4, t5) t1, t2, t3, t4, t5 = [catt * t for catt, t in zip(catt_weights, [t1,t2,t3,t4,t5])]
- 通道注意力细化:在空间注意力已经增强特征图后,接下来会通过通道注意力机制进一步优化特征表达。通过
Channel_Att_Bridge
模块,模型会计算出每个通道的重要性权重(即catt_weights
)。 - 加权融合:这些通道注意力权重将与原始特征图相乘,进一步细化每个通道的表达。通过这种方式,模型能够学习哪些通道包含更多有用信息,从而在通道级别调整特征的贡献。
- 通道注意力细化:在空间注意力已经增强特征图后,接下来会通过通道注意力机制进一步优化特征表达。通过
2.6 PVMLayer
类
与 Mamba
类 大同小异,可参考博主的另一篇关于 Mamba
的文章,在这里不做过多赘述。
class PVMLayer(nn.Module):
def __init__(self, input_dim, output_dim, d_state=16, d_conv=4, expand=2):
super().__init__()
self.norm = nn.LayerNorm(input_dim) # 层归一化
self.mamba = Mamba(
d_model=input_dim//4, # 拆分输入为4部分并行处理
d_state=d_state, # 状态空间维度
d_conv=d_conv, # 局部卷积核大小
expand=expand # 扩展因子
)
self.proj = nn.Linear(input_dim, output_dim) # 线性投影
self.skip_scale = nn.Parameter(torch.ones(1)) # 跳跃连接权重
def forward(self, x):
B, C = x.shape[:2]
x_flat = x.reshape(B, C, -1).transpose(1, 2) # 展平为序列
x_norm = self.norm(x_flat)
# 分块并行处理
x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2)
x_mamba = torch.cat([
self.mamba(x1) + self.skip_scale * x1,
# ...(其他块处理)
], dim=2)
x_mamba = self.proj(x_mamba)
return x_mamba.transpose(1, 2).reshape(B, -1, *x.shape[2:]) # 恢复空间维度
- 序列化处理:将图像特征展平为序列,适配 Mamba 的序列建模能力。
- 分块并行:将输入拆分为4块,独立处理后再拼接,提升计算效率。
- 残差连接:通过可学习的
skip_scale
平衡原始特征与 Mamba 输出。
若有想交流的小伙伴可以私信我,欢迎大家!