论文信息
题目:MixFormer: A Mixed CNN-Transformer Backbone for Medical Image Segmentation
MixFormer:一种用于医学图像分割的混合CNN-Transformer骨干网络
作者:un Liu, Kunqi Li, Chun Huang, Hua Dong, Yusheng Song, and Rihui Li
论文创新点
- 混合CNN-Transformer骨干网络:本文提出了一种**混合CNN-Transformer(MixFormer)**特征提取骨干网络,在下采样过程中无缝集成了Transformer的全局上下文信息和CNN的局部细节信息。
- 多尺度空间感知融合(MSAF)模块:为了捕捉跨尺度的特征依赖关系,作者引入了多尺度空间感知融合(MSAF)模块。
- 混合多分支扩张注意力(MMDA)机制:在编码和解码阶段之间,作者提出了混合多分支扩张注意力(MMDA)机制,用于弥合语义差距并强调特定区域。
- 基于CNN的上采样方法:为了恢复低级特征并提高分割精度,作者采用了基于CNN的上采样方法。
摘要
Transformer通过自注意力机制在医学图像处理中取得了显著进展,能够建模长距离语义依赖关系,但其缺乏卷积神经网络(CNN)捕捉局部空间细节的能力。本文提出了一种基于混合CNN-Transformer(MixFormer)特征提取骨干网络的新型分割网络,旨在提升医学图像分割的效果。MixFormer网络在下采样过程中无缝集成了Transformer和CNN架构的全局和局部信息。为了全面捕捉跨尺度的视角,作者引入了多尺度空间感知融合(MSAF)模块,有效实现了粗粒度与细粒度特征表示之间的交互。此外,作者还提出了混合多分支扩张注意力(MMDA)模块,用于在编码和解码阶段之间弥合语义差距,同时强调特定区域。最后,作者采用基于CNN的上采样方法来恢复低级特征,显著提高了分割精度。通过在多个主流医学图像数据集上的实验验证,MixFormer表现出卓越的性能。在Synapse数据集上,该方法达到了82.64%的平均Dice相似系数(DSC)和12.67 mm的平均Hausdorff距离(HD)。在自动心脏诊断挑战(ACDC)数据集上,DSC达到了91.01%。在国际皮肤成像协作(ISIC)2018数据集上,模型的平均交并比(mIoU)为0.841,准确率为0.958,精确率为0.910,召回率为0.934,F1得分为0.913。在Kvasir-SEG数据集上,平均Dice为0.9247,mIoU为0.8615,精确率为0.9181,召回率为0.9463。在CVC-ClinicDB数据集上,平均Dice为0.9441,mIoU为0.8922,精确率为0.9437,召回率为0.9458。这些结果表明,MixFormer在分割性能上优于大多数主流分割网络,如CNN和其他基于Transformer的结构。
关键词
医学图像分割(SEG),混合卷积神经网络(CNN)-Transformer骨干网络,混合多分支扩张注意力(MMDA),多尺度空间感知融合(MSAF)。
I. 引言
医学图像分割(SEG)是医学图像处理中不可或缺的一步。通过提取器官和病变区域等不同组成部分的关键特征,该方法为临床诊断和病理研究提供了可靠的基础。然而,标记特定器官和病变区域的过程耗时、费力且效率低下,需要专业人员执行这些任务。由于数据集稀缺和准确标记位置的挑战,医学图像分割变得更加复杂。因此,该领域的研究广泛关注如何利用计算机视觉技术开发精确且高度鲁棒的医学图像分割系统,以应对数据稀缺和准确标记等复杂问题。
凭借其卓越的特征提取能力,卷积神经网络(CNN)近年来在医学图像分割应用中取得了巨大成功。例如,U-Net、3D-UNet、ResUNet、UNet++和DenseUNet在分割各种器官的精细医学任务中表现出色。尽管基于CNN的网络具有出色的分割能力,但由于卷积操作的感受野固有局限性,这些网络无法建立特征信息之间的全局依赖关系,导致难以与远程语义信息交互,从而导致上下文特征信息的捕捉不足。通过引入注意力机制和金字塔模型等技术,研究人员努力扩大CNN的感受野以更好地获取全局信息。然而,这些操作需要深度堆叠的卷积层和连续的下采样,随着网络层数的增加,可能导致梯度消失或爆炸问题。这可能会导致局部信息的丢失,尤其是在处理小数据集时,训练变得更加困难。
为了弥补卷积在建模长距离依赖关系方面的不足,研究人员将Transformer引入计算机视觉领域。Transformer最初用于自然语言处理(NLP)任务,其独特的自注意力机制能够根据输入内容自适应地调整感受野。Transformer在获取全局语义信息方面优于卷积操作。例如,Dosovitskiy等人使用Transformer进行图像识别工作,并取得了与其他最先进技术相当的结果。为了建模局部连接并解决视觉元素大规模变化的问题,Liu等人提出了一种基于移位窗口多头自注意力机制(SW-MSA)的分层Swin Transformer。然而,尽管Transformer在捕捉全局语义依赖关系方面表现出色,但在有效封装复杂细节和长序列计算效率方面仍面临挑战。最近,Mamba及其变体的表现引人注目,这些方法将结构化状态空间模型(SSMs)与Transformer的MLP机制结合为块,使这些模型能够沿序列长度维度选择性地传播或遗忘信息,从而提高了长序列的计算效率。然而,与Transformer的显式注意力机制相比,这些方法的SSM机制在处理长距离依赖关系时缺乏直观性,在某些密集分割任务中的性能有限。
最近的研究尝试将CNN和Transformer结合起来,以充分利用它们各自的优势,从而有效缓解单一模块的局限性。大多数主流的医学图像分割技术采用U形编解码器设计。尽管这种U形架构取得了相当大的成功,但分割性能仍有待优化。这主要是由于编码和解码阶段之间的语义差距没有得到全面解决。在涉及语义差距的场景中,由于深层特征对目标识别和浅层特征对边缘分割的双重重要性,有效融合特征变得具有挑战性。
综上所述,本文提出了一种名为MixFormer的新型混合CNN-Transformer骨干网络,用于医学图像分割。该框架不仅解决了上述挑战,还在多个评估指标上超越了当前的最先进方法。本文的主要创新和贡献总结如下:
- 在先进的MixFormer架构中,混合编码器网络在下采样过程的每个尺度上协同集成了Transformer的全局上下文信息和CNN的局部细节信息。
- 多尺度空间感知融合(MSAF)模型确保了粗粒度和细粒度特征表示之间的语义一致性。为了捕捉跨尺度特征依赖关系,该模块巧妙地融合了不同尺度上共识区域之间的信息。
- 作者引入了混合多分支扩张注意力(MMDA)机制来跳过连接步骤。该方法有效过滤了冗余语义信息,同时提升了关键语义特征。这一过程极大地减少了编码器和解码器之间的特征语义差距。
- 作者在五个主流医学图像分割数据集上进行了实验验证,展示了所提出的MixFormer的有效性和泛化能力。结果表明,MixFormer模型在分割性能上始终优于最先进的模型。
III. 方法
A. 网络结构
所提出的MixFormer遵循典型的U形架构,主要由编码器、解码器和跳跃连接组成。编码器由两部分组成:基于CNN-Transformer的混合特征提取骨干网络和MSAF模型。解码器由基本的CNN单元组成。此外,在跳跃连接部分引入了混合多分支扩张注意力(MMDA)模块。在该方法中,混合特征提取器用于从Transformer和CNN中获取多层视觉特征,分别包含全局和局部线索。MSAF机制进一步在多个尺度之间交互语义信息,以获得增强的特征表示。为了获得更好的上采样性能,作者在编码器的最深层实现了Res-Conv块,以进一步加强全局特征表示。最后,作者将MMDA机制集成到跳跃连接中,有效地将融合的语义特征传递到解码器,以进行最终的分割预测。
B. 混合特征提取骨干网络
在MixFormer中,作者使用了一种混合编码器框架,该框架由Swin Transformer和Res2Net50组成,用于在编码阶段提取特征。这使得编码器能够在每个尺度上有效捕捉局部细节和全局上下文信息。Res2Net50在提取多尺度特征方面优于传统的CNN。此外,Swin Transformer在定义的窗口内对图像块进行自注意力计算,其计算开销与图像大小呈线性关系,这与ViT在整个图像块集上进行自注意力计算形成对比。
1) Res2Net50模型
标准的卷积核只能从给定图像中提取相同尺度的特征,因此需要构建不同大小的卷积核以获得多尺度特征。Gao等人提出了Res2Net,专门设计用于获取多尺度特征图。对于包括ResNet和ResNeXt在内的许多CNN架构,瓶颈块是基本组件。Res2Net寻求一种替代设计,以在保持与瓶颈块相当的计算负担的同时,获得更强大的多尺度特征提取能力。Res2Net块与瓶颈块的区别如图3所示。
具体来说,在1×1卷积之后,输入特征图被均匀地分为S个子集,表示为 x i x_i xi,其中 i ∈ { 1 , 2 , … , S } i \in \{1, 2, \ldots, S\} i∈{1,2,…,S}。与输入特征图相比,每个子集 x i x_i xi具有相同的空间大小,但只有1/S的通道数。除了 x 1 x_1 x1,每个 x i x_i xi对应一个3×3卷积操作符,表示为 K i ( ⋅ ) K_i(\cdot) Ki(⋅)。 K i ( ⋅ ) K_i(\cdot) Ki(⋅)的结果表示为 y i y_i yi。然后, y i − 1 ( ⋅ ) y_{i-1}(\cdot) yi−1(⋅)与子集 x i x_i xi结合,并输入到 K i ( ⋅ ) K_i(\cdot) Ki(⋅)中。因此, y i y_i yi可以表示为:
y i = { x i , i = 1 K i ( x i ) , i = 2 K i ( x i + y i − 1 ) , 2 < i ≤ S y_i = \begin{cases} x_i, & i = 1 \\ K_i(x_i), & i = 2 \\ K_i(x_i + y_{i-1}), & 2 < i \leq S \end{cases} yi=⎩ ⎨ ⎧xi,Ki(xi),Ki(xi+yi−1),i=1i=22<i≤S
在本研究中,通过使用Res2Net50网络作为局部细节提取器,作者有效地分割了原始图像( H × W H \times W H×W),以获得不同分辨率的特征金字塔输出 F 1 F_1 F1到 F 4 F_4 F4。然后,这些输出的通道维度通过1×1卷积操作映射到与相应Swin Transformer块相同的嵌入维度,并作为补偿Res2Net模型在不同阶段缺失的长距离依赖关系输入到Swin Transformer模块中。在作者的方法中,有四个下采样阶段用于获取局部语义信息。每个阶段的特征图输出大小为 H / 4 × W / 4 H/4 \times W/4 H/4×W/4、 H / 8 × W / 8 H/8 \times W/8 H/8×W/8、 H / 16 × W / 16 H/16 \times W/16 H/16×W/16和 H / 32 × W / 32 H/32 \times W/32 H/32×W/32,通道数分别为 D D D、 2 D 2D 2D、 4 D 4D 4D和 8 D 8D 8D。
2) Swin Transformer模型
如图4所示,ViT对整个嵌入的图像块进行自注意力计算。典型的MSA模块中的一个主要问题是计算复杂度是图像空间维度的二次函数。Swin Transformer引入了W-MSA和SW-MSA技术来缓解这一限制。图5(b)显示了两个连续的Swin Transformer块,每个块由LayerNorm(LN)层、MSA模块、多层感知器(MLP)层和残差连接组成。此外,两个相邻的MSA块被替换为W-MSA和SW-MSA块。W-MSA的计算开销与图像大小呈线性关系,因为它在定义的窗口内进行自注意力计算。然而,由于这些窗口之间缺乏信息交换,其建模长距离依赖关系的能力有限。相比之下,SW-MSA引入了循环移位操作,将每个窗口向左上方向移动。这一策略使得各种不相邻的子窗口能够有效地相互交互。因此,Swin Transformer能够有效地建模上下文依赖关系并获取高效的分层特征表示。
在Swin Transformer中,输入特征图被划分为不重叠的窗口,每个窗口包含 M × M M \times M M×M(本文中设置为7)个图像块。作为(S)W-MSA的输入, Z l − 1 ∈ R L × D Z^{l-1} \in \mathbb{R}^{L \times D} Zl−1∈RL×D是一个长度为 L L L、维度为 D D D的图像块序列。基于这种窗口划分策略,Swin Transformer块的输入和输出可以表示为:
Z ^ l = W-MSA ( LN ( Z l − 1 ) ) + Z l − 1 Z l = MLP ( LN ( Z ^ l ) ) + Z ^ l Z ^ l + 1 = SW-MSA ( LN ( Z l ) ) + Z l Z l + 1 = MLP ( LN ( Z ^ l + 1 ) ) + Z ^ l + 1 \hat{Z}^l = \text{W-MSA}(\text{LN}(Z^{l-1})) + Z^{l-1} \\ Z^l = \text{MLP}(\text{LN}(\hat{Z}^l)) + \hat{Z}^l \\ \hat{Z}^{l+1} = \text{SW-MSA}(\text{LN}(Z^l)) + Z^l \\ Z^{l+1} = \text{MLP}(\text{LN}(\hat{Z}^{l+1})) + \hat{Z}^{l+1} Z^l=W-MSA(LN(Zl−1))+Zl−1Zl=MLP(LN(Z^l))+Z^lZ^l+1=SW-MSA(LN(Zl))+ZlZl+1=MLP(LN(Z^l+1))+Z^l+1
其中, Z ^ l \hat{Z}^l Z^l和 Z l Z^l Zl分别表示第 l l l个块的W-MSA和MLP的输出。类似地, Z ^ l + 1 \hat{Z}^{l+1} Z^l+1和 Z l + 1 Z^{l+1} Zl+1分别表示第 ( l + 1 ) (l+1) (l+1)个块的SW-MSA和MLP的结果。 LN ( ⋅ ) \text{LN}(\cdot) LN(⋅)表示层归一化操作。根据先前的工作,自注意力计算如下:
Q = z l W Q , K = z l W K , V = z l W V Attention ( Q , K , V ) = SoftMax ( Q K T d + B ) V Q = z^l W_Q, \quad K = z^l W_K, \quad V = z^l W_V \\ \text{Attention}(Q, K, V) = \text{SoftMax}\left(\frac{QK^T}{\sqrt{d}} + B\right) V Q=zlWQ,K=zlWK,V=zlWVAttention(Q,K,V)=SoftMax(dQKT+B)V
其中, Q Q Q、 K K K和 V ∈ R L × d V \in \mathbb{R}^{L \times d} V∈RL×d分别表示查询、键和值矩阵; d d d表示查询或键的维度。 W Q W_Q WQ、 W K W_K WK和 W V ∈ R D × d W_V \in \mathbb{R}^{D \times d} WV∈RD×d表示三个投影矩阵的可学习参数。此外, B ∈ R L × L B \in \mathbb{R}^{L \times L} B∈RL×L表示偏置矩阵。
为了弥补Res2Net模型在全局上下文信息方面的缺失,作者将局部细节特征图 F i F_i Fi( i = 1 , 2 , 3 , 4 i = 1, 2, 3, 4 i=1,2,3,4)划分为图像块,并将其输入到Swin Transformer中。Swin Transformer有四个阶段,每个阶段包含特定数量的块,能够建模长距离依赖关系。这使得作者能够获得四个阶段的特征表示 S 1 S_1 S1到 S 4 S_4 S4,其中包含全局上下文信息。此外,作者在每个阶段采用了图像块合并方法,将特征图的通道维度加倍,同时降低分辨率,然后将其输入到下一阶段的Swin Transformer块中。Swin Transformer四个阶段的特征图分辨率分别为 H / 4 × W / 4 H/4 \times W/4 H/4×W/4、 H / 8 × W / 8 H/8 \times W/8 H/8×W/8、 H / 16 × W / 16 H/16 \times W/16 H/16×W/16和 H / 32 × W / 32 H/32 \times W/32 H/32×W/32,通道数分别为 D ′ D' D′、 2 D ′ 2D' 2D′、 4 D ′ 4D' 4D′和 8 D ′ 8D' 8D′。最后,局部空间信息 F i F_i Fi( i = 1 , 2 , 3 , 4 i = 1, 2, 3, 4 i=1,2,3,4)和全局语义信息 S j S_j Sj( j = 1 , 2 , 3 , 4 j = 1, 2, 3, 4 j=1,2,3,4)在相应层次上连接,以获得四个尺度上的增强特征表示 L 1 L_1 L1到 L 4 L_4 L4。然后,这些特征表示被输入到MSAF模型中进行不同尺度之间的信息交互。
C. MSAF模型
已经证明,多尺度信号在分割具有复杂位置和大小变化的器官时极为有价值。为了解决多尺度特征之间的语义信息交互问题,作者利用了MSAF模块。该模块包括四个步骤:图像块匹配、尺度融合、尺度分割和图像块反转,这些步骤共同实现了跨尺度依赖关系的建立。MSAF模块的网络结构如图6所示。
具体来说,作者首先将所有特征图正则化到相同的通道维度 D ′ D' D′,并通过下采样操作将它们分配到具有相同颜色的边界框中(即红色)。如果 p j i p_j^i pji被视为第 i i i个尺度上的第 j j j个图像块,那么 p j + 1 i p_{j+1}^i pj+1i、 p j + 2 i p_{j+2}^i pj+2i和 p j + 3 i p_{j+3}^i pj+3i将是匹配的下采样图像块。通过这一步骤,作者能够在从第 i i i到第 ( i + 3 ) (i+3) (i+3)个尺度的四个连续特征映射上定位相关的空间感知图像块,保留最相关图像块的空间对应关系。
然后,作者通过尺度融合方法聚合不同尺度上的空间相关图像块,生成一系列标记。该过程可以描述为:
[ flatten ( p j 1 ) , … , flatten ( p j N ) ] → p j cat [\text{flatten}(p_j^1), \ldots, \text{flatten}(p_j^N)] \rightarrow p_j^{\text{cat}} [flatten(pj1),…,flatten(pjN)]→pjcat
其中, flatten ( ⋅ ) \text{flatten}(\cdot) flatten(⋅)表示拼接操作。由于这一系列标记缺乏2D空间信息,作者利用Transformer块来确保其空间结构的一致性。图像块之间的跨尺度空间感知依赖关系可以通过以下方程获得:
p ^ j cat = W-MSA ( LN ( p j cat ) ) + p j cat p ^ j cat = MLP ( LN ( p ^ j cat ) ) + p ^ j cat p ^ j cat = SW-MSA ( LN ( p ^ j cat ) ) + p ^ j cat p ^ j cat = MLP ( LN ( p ^ j cat ) ) + p ^ j cat \hat{p}_j^{\text{cat}} = \text{W-MSA}(\text{LN}(p_j^{\text{cat}})) + p_j^{\text{cat}} \\ \hat{p}_j^{\text{cat}} = \text{MLP}(\text{LN}(\hat{p}_j^{\text{cat}})) + \hat{p}_j^{\text{cat}} \\ \hat{p}_j^{\text{cat}} = \text{SW-MSA}(\text{LN}(\hat{p}_j^{\text{cat}})) + \hat{p}_j^{\text{cat}} \\ \hat{p}_j^{\text{cat}} = \text{MLP}(\text{LN}(\hat{p}_j^{\text{cat}})) + \hat{p}_j^{\text{cat}} p^jcat=W-MSA(LN(pjcat))+pjcatp^jcat=MLP(LN(p^jcat))+p^jcatp^jcat=SW-MSA(LN(p^jcat))+p^jcatp^jcat=MLP(LN(p^jcat))+p^jcat
接下来,作者使用尺度分割方法将获得的增强序列按照拼接顺序反转为图像块。该操作如下所示:
Inverse ( Split ( p ^ j cat ) ) → p ^ j 1 , … , p ^ j N \text{Inverse}(\text{Split}(\hat{p}_j^{\text{cat}})) \rightarrow \hat{p}_j^1, \ldots, \hat{p}_j^N Inverse(Split(p^jcat))→p^j1,…,p^jN
其中, Split ( ⋅ ) \text{Split}(\cdot) Split(⋅)操作是 flatten ( ⋅ ) \text{flatten}(\cdot) flatten(⋅)操作的逆操作。
最后,图像块反转将所有相同尺度的图像块组合成增强的特征图 p 1 p_1 p1到 p 4 p_4 p4,这些特征图与CNN特征金字塔输出 F 1 F_1 F1到 F 4 F_4 F4融合,并通过跳跃连接进一步输入到解码器阶段。
D. 混合多分支扩张注意力模型
神经网络中浅层和深层特征具有不同的表示能力。因此,迫切需要寻求一种解决方案,以有效弥合编码和解码阶段相应特征之间的语义差距,同时最小化由采样过程引发的信息丢失。为此,作者引入了一种混合多分支扩张注意力(MMDA)模型,从三个维度(即通道、空间和全局上下文)映射特征,综合优化提取特征的信息丢失。该方法有效地减少了提取特征中的信息丢失,从而帮助解码器准确恢复特征图的原始分辨率。作者将MSAF每一层(除了最深层的下采样层)提取的特征图与相应的CNN特征金字塔输出融合。然后,通过跳跃连接将该组合输入到MMDA模块中进行特征提取。最终,该方法实现了与相应解码器层的和谐特征融合。MMDA模型由三个单元组成,包括全局自注意力单元、通道注意力单元和空间注意力单元,如图7所示。
通道注意力单元(图7的第I部分)主要获取图像纹理信息,同时关注每个通道的特征交互。对于给定的输入特征图 F ∈ R C × H × W F \in \mathbb{R}^{C \times H \times W} F∈RC×H×W,通过MaxPooling和AvgPooling操作沿空间轴提取特征信息。获得的特征图 C M ∈ R C × 1 × 1 C_M \in \mathbb{R}^{C \times 1 \times 1} CM∈RC×1×1和 C A ∈ R C × 1 × 1 C_A \in \mathbb{R}^{C \times 1 \times 1} CA∈RC×1×1分别经过1×1卷积和LeakyReLU激活,然后进行逐元素加法操作。最后,使用Sigmoid函数获得通道注意力图。具体流程如下:
C M = LeakyReLU ( Conv ( MaxPool ( F ) ) ) C A = LeakyReLU ( Conv ( AvgPool ( F ) ) ) C ( F ) = Sigmoid ( C M ⊕ C A ) , C ( F ) ∈ R C × 1 × 1 C_M = \text{LeakyReLU}(\text{Conv}(\text{MaxPool}(F))) \\ C_A = \text{LeakyReLU}(\text{Conv}(\text{AvgPool}(F))) \\ C(F) = \text{Sigmoid}(C_M \oplus C_A), \quad C(F) \in \mathbb{R}^{C \times 1 \times 1} CM=LeakyReLU(Conv(MaxPool(F)))CA=LeakyReLU(Conv(AvgPool(F)))C(F)=Sigmoid(CM⊕CA),C(F)∈RC×1×1
其中, C ( F ) C(F) C(F)表示通道注意力图, ⊕ \oplus ⊕表示逐像素求和。
空间注意力单元(图7的第II部分)的主要目标是识别网络中最重要区域以进行进一步处理。受先前使用坐标注意力(CA)模块的工作启发,作者采用CA技术作为MMDA模型的空间注意力部分,以编码具有精确空间位置信息的特征表示。该方法分为两个子步骤:坐标信息嵌入和CA生成。具体来说,给定特征图 F ∈ R C × H × W F \in \mathbb{R}^{C \times H \times W} F∈RC×H×W,首先使用尺寸为 ( H , 1 ) (H, 1) (H,1)或 ( 1 , W ) (1, W) (1,W)的池化核沿水平和垂直坐标对每个通道进行编码。因此,双向坐标信息嵌入的输出如下:
F H = Avgpool ( H , 1 ) ( F ) , F H ∈ R C × H × 1 F W = Avgpool ( 1 , W ) ( F ) , F W ∈ R C × 1 × W F_H = \text{Avgpool}_{(H, 1)}(F), \quad F_H \in \mathbb{R}^{C \times H \times 1} \\ F_W = \text{Avgpool}_{(1, W)}(F), \quad F_W \in \mathbb{R}^{C \times 1 \times W} FH=Avgpool(H,1)(F),FH∈RC×H×1FW=Avgpool(1,W)(F),FW∈RC×1×W
在CA生成步骤中,首先将两个特征图 F H F_H FH和 F W F_W FW拼接,并使用共享的1×1卷积操作符生成包含水平和垂直方向空间信息的中间特征图 f f f。具体操作如下:
f = δ ( BN ( Conv ( Cat ( F H , F W ) ) ) ) , f ∈ R C / r × 1 × ( H + W ) f = \delta(\text{BN}(\text{Conv}(\text{Cat}(F_H, F_W)))), \quad f \in \mathbb{R}^{C / r \times 1 \times (H + W)} f=δ(BN(Conv(Cat(FH,FW)))),f∈RC/r×1×(H+W)
其中, δ ( ⋅ ) \delta(\cdot) δ(⋅)表示非线性激活函数, BN \text{BN} BN表示批归一化操作。缩减比 r r r通常用于最小化模型的计算复杂度(例如16)。
接下来,作者将 f f f沿空间维度分割为两个独立的特征表示 f H f_H fH和 f W f_W fW。然后,使用两个1×1卷积将它们转换为与输入特征图 F F F相同数量的通道。具体操作如下:
f H = Split ( f ) H , f H ∈ R C / r × H × 1 f W = Split ( f ) W , f W ∈ R C / r × W × 1 S H = Sigmoid ( Conv ( f H ) ) , S H ∈ R C × H × 1 S W = Sigmoid ( Conv ( f W ) ) , S W ∈ R C × 1 × W f_H = \text{Split}(f)_H, \quad f_H \in \mathbb{R}^{C / r \times H \times 1} \\ f_W = \text{Split}(f)_W, \quad f_W \in \mathbb{R}^{C / r \times W \times 1} \\ S_H = \text{Sigmoid}(\text{Conv}(f_H)), \quad S_H \in \mathbb{R}^{C \times H \times 1} \\ S_W = \text{Sigmoid}(\text{Conv}(f_W)), \quad S_W \in \mathbb{R}^{C \times 1 \times W} fH=Split(f)H,fH∈RC/r×H×1fW=Split(f)W,fW∈RC/r×W×1SH=Sigmoid(Conv(fH)),SH∈RC×H×1SW=Sigmoid(Conv(fW)),SW∈RC×1×W
其中, S H S_H SH和 S W S_W SW随后被扩展并分别用作注意力权重。最后,CA块的结果可以表示为:
S ( F ) = F × S H × S W , S ( F ) ∈ R C × H × W S(F) = F \times S_H \times S_W, \quad S(F) \in \mathbb{R}^{C \times H \times W} S(F)=F×SH×SW,S(F)∈RC×H×W
其中, S ( F ) S(F) S(F)表示最终的空间注意力图。
全局注意力单元(图7的第III部分)利用Swin Transformer块结构构建全局上下文依赖关系之间的相关性。该操作可以弥补空间和通道注意力单元的信息丢失。首先,输入特征图 F ∈ R C × H × W F \in \mathbb{R}^{C \times H \times W} F∈RC×H×W被划分为一系列图像块。由于图像块之间没有2D空间结构信息,因此需要将图像块输入到(S)W-MSA机制中,以计算图像块之间的相似性表示关系,并从中恢复位置信息。最后,增强的特征序列被转换为全局注意力图 G ( F ) ∈ R C × H × W G(F) \in \mathbb{R}^{C \times H \times W} G(F)∈RC×H×W。该步骤的细节可以在公式(2)中找到。
最后,作者将 S ( F ) S(F) S(F)和 C ( F ) C(F) C(F)相乘,并进一步加上 G ( F ) G(F) G(F),得到混合多分支扩张注意力图 MMDA ( F ) ∈ R C × H × W \text{MMDA}(F) \in \mathbb{R}^{C \times H \times W} MMDA(F)∈RC×H×W:
MMDA ( F ) = [ S ( F ) ⊗ C ( F ) ] ⊕ G ( F ) \text{MMDA}(F) = [S(F) \otimes C(F)] \oplus G(F) MMDA(F)=[S(F)⊗C(F)]⊕G(F)
其中, ⊗ \otimes ⊗表示使用Python广播机制的乘法, ⊕ \oplus ⊕表示逐元素求和。
E. 解码器
为了增强有效感受野并使其更容易混合远距离上下文信息,融合后的特征图被输入到最深层的下采样层的Res-Conv块中。如图2所示,Res-Conv块由 L L L层(本文中设置为7)组成,通过深度卷积(即核大小 k = 7 k=7 k=7)和点卷积(即核大小 k = 1 k=1 k=1)混合空间位置和通道位置,进一步获取上下文依赖关系。此外,深度卷积中的组通道数等于输入特征图的通道数,并且每个卷积操作都应用了高斯误差线性单元(GeLU)激活和批归一化,如下所示:
f l = BN ( σ 1 ( DepthwiseConv ( f l − 1 ) ) ) + f l − 1 f l = BN ( σ 1 ( PointwiseConv ( f l ) ) ) f_l = \text{BN}(\sigma_1(\text{DepthwiseConv}(f_{l-1}))) + f_{l-1} \\ f_l = \text{BN}(\sigma_1(\text{PointwiseConv}(f_l))) fl=BN(σ1(DepthwiseConv(fl−1)))+fl−1fl=BN(σ1(PointwiseConv(fl)))
其中, f l f_l fl表示Res-Conv块中第 l l l层的输出特征图, σ 1 \sigma_1 σ1表示GeLU激活, BN \text{BN} BN表示批归一化。 DepthwiseConv \text{DepthwiseConv} DepthwiseConv和 PointwiseConv \text{PointwiseConv} PointwiseConv表示深度可分离卷积的两个主要阶段,其网络架构如图8所示。
在 DepthwiseConv \text{DepthwiseConv} DepthwiseConv阶段(图8的(a)部分),每个卷积核被指定处理输入图像的单个通道,生成与输入通道数匹配的中间特征图。例如,当处理 5 × 5 × 3 5 \times 5 \times 3 5×5×3的彩色图像时,该过程生成三个大小为 5 × 5 × 3 5 \times 5 \times 3 5×5×3的中间特征图。然而,这种每个通道独立进行卷积的方法未能充分利用通道间特征相关性。因此,它忽略了同一空间位置上通道之间的有价值信息交互。为了解决这一局限性并获得最终输出特征图, PointwiseConv \text{PointwiseConv} PointwiseConv阶段(图8的(b)部分)是一个关键操作。在这里,使用大小为 1 × 1 × C 1 \times 1 \times C 1×1×C的卷积核,其中 C C C表示前一层特征图的通道数。通过 PointwiseConv \text{PointwiseConv} PointwiseConv过程,每个卷积核沿通道维度组合中间特征图,有效地生成新的特征图。该方法显著减少了参数数量和计算复杂度,同时保持了令人满意的网络性能。
最后,作者采用双线性插值方法对特征图进行上采样。解码器由四个阶段组成,每个阶段包括一个典型的 3 × 3 3 \times 3 3×3卷积块、一个 2 × 2 2 \times 2 2×2上采样操作和ReLU激活函数。解码器首先接收低分辨率( H / 32 × W / 32 H/32 \times W/32 H/32×W/32)的融合特征图,经过四次连续的上采样操作后生成高分辨率( H × W H \times W H×W)的特征图。然后,重建的特征图通过分割头中的 3 × 3 3 \times 3 3×3卷积生成最终的分割结果,其中通道数与预测类别数相同。
F. 总损失
在本研究中,损失函数采用了交叉熵损失和Dice损失函数,以加速模型的收敛。较小的损失函数表示预测值与真实值更接近。模型的总损失函数如下:
L Total = λ 1 L Cross-entropy + λ 2 L Dice L_{\text{Total}} = \lambda_1 L_{\text{Cross-entropy}} + \lambda_2 L_{\text{Dice}} LTotal=λ1LCross-entropy+λ2LDice
其中, λ 1 \lambda_1 λ1和 λ 2 \lambda_2 λ2是关于损失函数的可调超参数(本文中设置为 λ 1 = λ 2 = 0.5 \lambda_1 = \lambda_2 = 0.5 λ1=λ2=0.5)。此外,Dice损失如下:
L Dice = 1 − 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ L_{\text{Dice}} = 1 - \frac{2|X \cap Y|}{|X| + |Y|} LDice=1−∣X∣+∣Y∣2∣X∩Y∣
其中, X X X表示预测结果, Y Y Y表示真实标签。Dice损失用于衡量预测结果与真实标签之间的重叠程度。交叉熵损失如下:
L Cross-entropy = − y log y ′ − ( 1 − y ) log ( 1 − y ′ ) L_{\text{Cross-entropy}} = -y \log y' - (1 - y) \log(1 - y') LCross-entropy=−ylogy′−(1−y)log(1−y′)
其中, y ′ y' y′表示模型的输出, y y y表示真实标签。可以看出,当 y = 1 y=1 y=1时,预测值越接近1,损失函数越小,反之亦然。
IV. 实验
作者在五个主流的公开医学图像分割数据集上进行了多种实验,以验证所提出的MixFormer方法的有效性。实验数据集、实验设置、对比实验和消融实验将在下面详细描述。
A. 数据集和实现细节
实验中使用了五个公开的医学图像分割数据集,包括Synapse多器官分割挑战(Synapse)、自动心脏诊断挑战(ACDC)、国际皮肤成像协作(ISIC)数据集以及息肉分割数据集,包括Kvasir-SEG数据集和CVC-ClinicDB数据集。
- Synapse数据集:包含来自30个病例的3779张腹部临床CT轴向图像,每个病例包含85-198张512×512像素的切片,体素空间分辨率为 ( [ 0.54 − 0.54 ] × [ 0.98 − 0.98 ] × [ 2.5 − 5.0 ] ) ([0.54-0.54] \times [0.98-0.98] \times [2.5-5.0]) ([0.54−0.54]×[0.98−0.98]×[2.5−5.0]) mm³。其中18个样本(2212张轴向切片)用作训练集,12个样本用作测试集(1567张轴向切片)。研究中检查了八个腹部器官(主动脉、胆囊、脾脏、左肾、右肾、肝脏、胰腺和胃)。
- ACDC数据集:包含来自不同患者的100个心脏MRI病例。每个患者的MRI图像中,左心室(LV)、右心室(RV)和心肌(MYO)由医生精确标记,每个病例包含18-20张切片。作者将数据集分为70个训练样本、10个验证样本和20个测试样本。
- ISIC 2018数据集:来自ISIC2018挑战赛,该比赛专注于通过三个图像分析任务检测和分析黑色素瘤,包括病变分割、属性检测和疾病分类。作者选择了病变分割任务数据集,其中训练图像包含2596张皮肤镜图像。采用五折交叉验证方法将2596张训练图像随机分为训练集和验证集进行实验,并从五次交叉验证中获得的分数取平均值以获得最终结果。
- Kvasir-SEG数据集:包含1000张息肉图像及其对应的真实标签。与其他数据集相比,该数据集中的图像空间维度从332×487到1920×1072不等,图像中出现的息肉的大小和形状也各不相同。其中有700个大于160×160的大息肉,48个小于64×64的小息肉,以及323个中等大小的息肉。在本文的实验设置中,800张图像用于训练,100张图像用于验证,100张图像用于测试。
- CVC-ClinicDB数据集:也称为CVC-612数据集,包含来自25个结肠镜检查视频的612张息肉图像,所有图像样本的空间维度为288×384。同样,数据集按8:1:1的比例分为训练集、验证集和测试集。
在提出的MixFormer中,作者将所有数据集的输入图像和Swin Transformer的图像块大小设置为224×224和4。对于Synapse数据集和ACDC数据集,模型以切片方式表示每个3D体素,然后将预测的2D切片堆叠以重建3D预测结果进行测试。此外,作者采用了简单的数据增强技术,如随机旋转和翻转,以防止过拟合并确保数据多样性。然而,对于弱边界边缘分割数据集,包括Kvasir-SEG、CVC-ClinicDB和ISIC 2018数据集,实验遵循[56]和[57]的策略,并在每个训练周期中对输入图像和真实标签对使用以下随机增强方法:1)高斯模糊,核大小为25×25;标准差从[0.001, 2]均匀采样;2)颜色抖动:亮度系数从[0.6, 1.4]均匀采样,对比度系数从[0.5, 1.5]均匀采样,饱和度系数从[0.75, 1.25]均匀采样,色调系数从[0.99, 1.01]均匀采样;3)水平和垂直翻转,概率均设置为0.5;4)仿射变换,旋转角度从[-180°, 180°]均匀采样,水平和垂直平移幅度从[-28, 28]均匀采样,缩放幅度从[0.5, 1.5]均匀采样,剪切角度从[-22.5°, 22°]均匀采样。需要注意的是,所有四种增强方法都应用于输入图像,而仅使用三种、四种和五种增强方法应用于相应的分割图。
在这里,作者提供了五个数据集在批量大小(BS)、学习率(LR)、最大训练周期(EP)、优化器(OPT)、动量(Momen.)和权重衰减(WD)方面的信息,如表I所示。此外,在本研究中,所有神经网络均在Windows 10桌面上使用开源机器学习框架PyTorch开发和实现,其中包括两块NVIDIA GeForce RTX 3090显卡,总共48 GB的GPU内存。表II描述了所提出模型的最终配置。
B. 评估指标
对于Synapse数据集,作者使用平均Dice相似系数(DSC)和平均Hausdorff距离(HD)作为分割结果的评估指标。同样,ACDC数据集也使用了平均DSC。特别是,DSC用于衡量两个区域之间的相似程度,其取值范围为0到1。DSC值与两个区域的相似度成正比。另一方面,HD可用于衡量两个区域之间的接近程度,两个区域越接近,HD值越小。
对于ISIC 2018数据集,作者使用平均交并比(mIoU)、精确率、召回率、F1、准确率和平均绝对误差(MAE)来评估分割结果。同样,Kvasir-SEG数据集和CVC-ClinicDB数据集也使用了这些指标,但排除了准确率。
C. 对比实验
为了验证所提出的MixFormer的有效性,作者将其与一些最先进的方法在Synapse数据集、ACDC数据集、ISIC 2018数据集、Kvasir-SEG数据集和CVC-ClinicDB数据集上进行了比较。
1) Synapse数据集的结果
作者进行了两组实验来验证MixFormer模型在Synapse数据集上的分割性能。作者的方法与基于CNN的模型(如U-Net、Att-UNet、R50 Att-UNet、R50-UNet和U-Net++)进行了比较。此外,作者还评估了一些结合CNN和ViT的ViT和混合技术,包括ViT、R50-ViT、TransUnet、MedT和SwinUnet。这两组实验的结果如表III所示。
2) ACDC数据集的结果
为了进一步验证所提出模型的鲁棒性和泛化能力,作者通过迁移学习将Mixformer模型应用于ACDC数据集。对比实验包括基于CNN的方法(如U-Net、U-Net++和R50 Att-UNet)以及基于Transformer的方法(如R50-ViT、SwinUNet和TransUNet)。如表IV所示,所提出的MixFormer(平均DSC为91.01%)在DSC指标上优于U-Net、U-Net++、R50 Att-UNet、R50-ViT、SwinUNet和TransUNet。
3) ISIC 2018数据集的结果
作者还将MixFormer与几种先进的分割技术在ISIC 2018数据集上进行了比较,总结如表V所示。与基于CNN的方法(如Att-UNet和UNet++)相比,所提出的MixFormer利用Swin Transformer提取了包含更多长距离语义信息的视觉特征。与基于Transformer的方法(如MedT、SwinUNet和TransUNet)相比,作者的方法使用了包含多维信息的MMDA模块,在解码阶段生成更精细的局部空间信息和更显著的纹理特征。
4) 息肉数据集的结果
对于Kvasir-SEG和CVC-ClinicDB息肉数据集,作者还与其他最先进的方法(如U-Net、U-Net++、PraNet、ResUNet++、TransUNet和SwinUNet)进行了定性和定量实验。定量结果如表VI和表VII所示。
D. 消融实验
为了评估MixFormer模型中每个组件的有效性,作者在Synapse数据集上进行了多项消融实验。这些实验包括:1)不同CNN模型的消融实验;2)MSAF模块的有效性;3)跳跃连接数量的影响;4)输入大小的影响;5)混合特征提取骨干网络的影响。
1) 不同CNN骨干网络的比较
作者首先研究了不同CNN骨干网络在Synapse数据集上的效果,以展示作者方法的分割效率。特别是,作者关注了ResNet及其衍生网络在Synapse数据集上的表现。如表VIII所示,使用Res2Net50骨干网络获得了最佳分割结果(平均DSC:82.64%,平均HD:12.67 mm),优于ResNet50、ResNeXt50和SeResNet50。
2) MSAF模块的有效性
如表IX所示,MSAF模块在增强分割性能方面发挥了重要作用。具体来说,由于在不同阶段学习了空间交互信息,平均DSC得分从81.66%增加到82.64%,平均HD从21.77 mm减少到12.67 mm。
3) 跳跃连接数量的影响
接下来,作者评估了MMDA模块在编码器和解码器之间数量的变化对Synapse数据集分割性能的影响。作者实验记录了四种情况:1)无MMDA(0-MMDA);2)仅在下采样的第三阶段(最深层)和解码器的跳跃连接之间添加MMDA(1-MMDA);3)在下采样的第二和第三阶段与解码器的跳跃连接之间分别添加MMDA(2-MMDA);4)在三个连续的下采样阶段与解码器的跳跃连接之间添加MMDA(3-MMDA)。如表X所示,实验记录了MMDA模块从第三阶段(图像大小14×14)到第一阶段(图像大小56×56)的分割性能变化。当在三个连续阶段使用MMDA模块时(3-MMDA,平均DSC:82.64%,平均HD:12.67 mm),获得了最佳性能。
4) 输入大小的影响
最后,表XI显示了所提出的MixFormer在使用网络输入分辨率为224×224和384×384时的实验结果。随着输入大小从224×224增加到384×384,Swin Transformer的输入标记序列变得更大,而图像块大小保持为4,从而提高了模型的分割性能。尽管模型的分割精度有所提高,但整个网络的计算负担也大大增加。为了确保所提出方法的运行效率,本文使用224×224分辨率作为实验输入。
5) 混合特征提取骨干网络的影响
为了阐明局部和全局分支对模型性能的不同影响,作者使用Synapse数据集对混合特征提取骨干网络进行了消融研究。该研究的结果如表XII所示。结果表明,当Res2Net分支单独作为编码器的特征提取组件时,平均DSC和HD指标分别为79.36%和23.19 mm。这种较差的表现是由于CNN在捕捉全局上下文信息方面的局限性。相反,当Swin Transformer分支作为特征提取器时,平均DSC和HD指标仅达到76.73%和28.07 mm,反映了ViT在描绘局部细粒度特征方面相对于CNN架构的相对不足。令人印象深刻的是,所提出的MixFormer通过将局部-全局交互(LGI)模块作为其基础特征提取网络,实现了卓越的分割性能。这强调了协同整合局部和全局信息以实现精确目标分割的重要性。
E. 跨数据库实验
为了确保泛化能力,作者在一个已见数据集上训练了模型和其他最先进模型,然后在一个未见数据集上测试这些模型。在这里,作者使用Kvasir-SEG数据集训练模型,并在CVC-ClinicDB数据集上测试模型。同样,作者在相反设置下进行了这项研究,即在CVC-ClinicDB上训练并在Kvasir-SEG上测试。
1) CVC-ClinicDB上的泛化结果
在Kvasir-SEG上训练并在CVC-ClinicDB上测试的泛化结果如表XIII所示。对于未见数据集的实验测试,MixFormer达到了令人满意的mDice为87.93%,mIoU为80.43%,mPrecision为89.13%,mRecall为89.81%,MAE为2.21%。特别是,所提出的MixFormer在mDice上优于现有的最先进技术,超过了PraNet方法产生的83.88%的mDice。
2) Kvasir-SEG上的泛化结果
同样,表XIV显示了作者的方法在Kvasir-SEG上测试的结果。作者的模型在所有指标上均优于其他最先进方法,mDice为86.56%,mIoU为79.71%,mPrecision为90.19%,mRecall为87.49%,MAE为4.01%。尽管PraNet以76.67%的mIoU和83.87%的mDice获得了第二好的性能,但作者的方法在mDice上超过了它2.69%,在mIoU上超过了它3.04%。
F. 可解释性分析
在深度学习中,准确性和简单性(或可解释性)之间往往存在权衡。经典的基于规则或专家系统提供了高可解释性,但通常缺乏精度和鲁棒性。为此,研究人员一直在探索可解释性和准确性之间的平衡。Selvaraju等人提出了梯度加权类激活映射(Grad-CAM)技术,该技术利用梯度信息生成定位图,突出显示用于预测的图像关键区域。这不仅增进了作者对模型内部工作机制的理解,还有助于提高模型的可信度和接受度。所有这些工作都是在不改变网络架构的情况下完成的。
因此,为了验证所提出的MixFormer能够准确定位医学图像分割模型感兴趣的图像区域,作者使用Grad-CAM方法可视化中间层特征图。对于每个测试图像,图像由训练好的模型进行预测。图13和图14显示了结肠直肠息肉的可视化结果,其中不同颜色表示像素对分割结果的重要性,即模型对每个像素的敏感性。热图中较深的红色表示该区域对最终预测结果的贡献最大,其次是黄色表示次重要区域,而蓝色区域对结果的影响较小,可以视为模型中的冗余信息。
V. 结论
在本文中,作者提出了MixFormer,一种结合全局和局部信息的新型医学图像分割网络,具有出色的分割性能。作者方法的主要贡献分为三个方面。首先,作者建立并实现了一种基于CNN-Transformer的混合特征提取骨干网络,该网络在建模长距离语义交互和保持浅层特征细节之间取得了平衡。其次,作者利用MSAF模块获取不同尺度之间的语义信息连接,确保多尺度特征的丰富性和特征提取的一致性。最后,作者提出了一种用于解码器上采样阶段的MMDA,可以通过跳跃连接抑制冗余语义信息,并最小化信息丢失。在本研究中,作者在五种不同类型的医学图像数据集上进行了实验,所有结果都达到了卓越的分割水平,并展示了所提出网络的强大泛化能力。
本文的研究存在一些局限性。众所周知,模型预训练对基于Transformer的方法的性能有显著影响。未来,作者将探索如何端到端地预训练Transformer以完成医学图像分割任务。此外,实践中的大多数医学图像数据是3D的,而本文中的输入图像是2D的。未来的研究需要建立一个3D医学图像分割网络,以增强网络通过医学图像的层间相关特征识别目标对象的能力。最后,对于密集分割任务,特别是在处理多类别前景分割挑战时,模型需要牺牲一定数量的参数和浮点运算(FLOPs)以确保分割精度(所提出的MixFormer具有较高的参数数量和FLOPs)。在未来的工作中,作者旨在通过减少参数数量和FLOPs来优化模型,以适应特定的医学分割任务。源代码可在https://github.com/LKQ-GITHUB-CODE/MixFormer/tree/master获取。
声明
本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。