MIXED TRANSFORMER U-NET FOR MEDICAL IMAGE SEGMENTATION
论文源:https://github.com/Dootmaan/MT- UNet
摘要:尽管U - Net在医学图像分割任务中取得了巨大成功,但它缺乏显式建模长距离依赖关系的能力。因此,视觉转换器( Vision Transformers )因其天生具有的通过自注意力( Self- Attention,SA )捕获长距离相关性的能力,近年来成为一种可供选择的分割结构。然而,Transformer通常依赖于大规模的预训练,计算复杂度较高。此外,SA只能对单个样本内的自相似性进行建模,忽略了整体数据集的潜在相关性。为了解决这些问题,本文提出了一种新颖的Transformer模块,称为混合Transformer模块( Mixed Transformer Module,MTM ),用于同时进行类间和类内的相似度学习。MTM首先通过我们精心设计的局部-全局高斯加权自注意力( LGG-SA )来高效地计算自相似性。然后,通过外部注意力( External Attention,EA )挖掘数据样本之间的相互联系。利用MTM,本文构建了一个名为Mixed Transformer U-Net ( MT-UNet )的U型模型,用于精确的医学图像分割。两个不同的公共数据集上测试本文提出的方法,实验结果表明本文提出的方法比其他先进的方法取得了更好的性能。
所提方法的示意图如图1所示。该网络基于编码器-解码器结构,在解码时使用跳跃连接来保持低级特征。如图所示,MTMs只用于空间尺寸较小的较深层以减少计算成本,而上层仍然使用经典的卷积操作。这是因为我们希望关注初始层上的局部关系,因为它们包含更多的高分辨率细节。通过使用卷积,我们还可以在模型中引入一些结构先验,这对于尺寸相对较小的医学图像数据集是有帮助的。值得注意的是,对于所有的Transformer模块,均采用2阶卷积/反卷积核来实现下采样/上采样以及通道扩展/压缩。
所提出的MTM概述如图2所示。如前所述,MTM由LGG - SA和EA两部分组成。LGG - SA用于建模不同粒度的短程依赖和长程依赖,而EA用于挖掘样本间的相关性。该模块被提出来替代原始的Transformer编码器,因为它在视觉任务上具有更好的性能和更低的时间复杂度。
局部-全局高斯加权自注意力
LGG - SA完美体现了聚焦计算的思想。与传统SA对所有令牌同等关注不同,LGG - SA由于使用了Local - Global策略和高斯掩码,可以更专注于附近区域。实验证明LGG - SA可以提高模型性能,节省计算资源。该模块的详细设计如图3所示。
局部-全局自注意力
SA旨在捕获输入序列中所有实体之间的相互联系。为了实现这个目标,SA引入了三个矩阵,分别是key ( K ),query ( Q )和value ( V )。这三个矩阵是输入X的线性变换。然而,在计算机视觉中,近距离区域之间的相关性往往比远距离区域之间的相关性更重要,并且在计算注意力图时不需要为更远的区域花费相同的代价。因此,我们提出"局部-全局自注意力"。局部SA计算每个窗口内部的自相似度。然后,将每个窗口内部的令牌聚合为一个全局令牌,表示窗口的主要信息。对于聚合函数,本文尝试了跨步卷积、最大池化等方法,其中轻量级动态卷积( Lightweight Dynamic Convolution,LDConv ) 表现最好。在得到降采样后的整个特征图之后,可以以较小的代价执行全局SA。
高斯加权轴向注意力
与使用原始SA的LSA不同,本文为GSA提出了高斯加权轴向注意力机制( Gaussian Weighted Axial Attention,GWAA )。GWAA通过一个可学习的高斯矩阵来增强每个查询对附近令牌的感知,同时由于轴向注意力,GWAA具有较低的时间复杂度。假设
Q
∈
R
H
p
×
W
p
Q\in R^{{H \over p }×{ W \over p}}
Q∈RpH×pW表示由聚集步骤得到的查询,对于Q中的查询
q
i
,
j
q_i,_j
qi,j,本文定义
q
i
,
j
q_i,_j
qi,j与其对应的
K
i
,
j
K_i,_j
Ki,j和
V
i
,
j
V_i,_j
Vi,j之间的欧氏距离为
D
i
,
j
D_i,_j
Di,j,其中
K
i
,
j
K_i,_j
Ki,j和
V
i
,
j
V_i,_j
Vi,j是由聚集后第i行第j列的令牌生成的矩阵。令q和K的相似度为S( q , K),高斯权重为
e
−
D
i
2
,
j
2
σ
2
e^-D^2_i,_j \over 2σ^2
2σ2e−Di2,j
位置( i , j)处的最终输出可表示为:
z
i
,
j
,
=
e
−
D
2
i
,
j
2
σ
2
S
o
f
t
m
a
x
(
S
(
q
i
,
j
,
K
i
,
j
)
)
V
i
,
j
z_i,_j, = {e^−D^2{_i,_j } \over {2σ^2}} Softmax(S(q_i,_j , K_i,_j ))V_i,_j
zi,j,=2σ2e−D2i,jSoftmax(S(qi,j,Ki,j))Vi,j
并且可以简单地用w来表示
D
i
2
,
j
D^2_i,_j
Di2,j之前的系数因子,
w
D
i
2
,
j
wD^2_i,_j
wDi2,j也充当了相对位置偏差,通过它可以强调MTM中的位置信息。它提高了显式提供相对关系的模型性能,这是普通的绝对位置嵌入无法做到的。
总体而言,对于给定的具有n个体素的图像,当p固定时,LSA的时间复杂度为O(n)。相比之下,由于轴向注意力的存在,GSA的时间复杂度为O(n√n)。因此,本文提出的LGG-SA算法的整体复杂度为O(n√n)。
实验结果
在腹部多器官数据集上的效果
在ACDC数据集上的效果