ConvFormer:改进医学图像分割的即插即用CNN风格转换器
摘要:Transformer在医学图像分割中被广泛研究,以建立成对的长程依赖关系(像素之间的长程依赖关系)。然而,相对有限的注释良好的医学图像数据使transformer难以提取不同的全局特征,(这句话指的是在医学图像数据中,往往存在着相对较少的注释信息,这些注释信息通常用于描述图像中的不同结构、病变或特征。由于注释信息有限,传统的深度学习模型如Transformer在处理这些医学图像时可能会遇到一些挑战。
具体来说,"Transformer" 是一种深度学习模型,通常用于自然语言处理和计算机视觉任务。它在处理图像时会将图像划分为不同的区域(例如图像块或像素),然后将这些区域分别编码并进行信息交互,以提取全局特征。然而,在医学图像中,有限的注释信息可能导致模型难以理解图像中的不同结构或病变,因为模型缺乏足够的指导来区分这些特征。
这意味着模型可能会在提取和理解医学图像中的全局特征方面表现不佳,因为它缺乏足够的标签或注释来帮助它学习这些特征。为了克服这个问题,研究人员通常需要采用其他方法,如迁移学习、数据增强或使用辅助信息来提高医学图像处理的性能。
)导致注意力崩溃,注意力图变得相似甚至相同(通常指的是在使用注意力机制的深度学习模型中的一种问题,其中模型的自注意力权重在不同位置或通道之间变得过于相似,导致模型无法有效地捕获不同位置或通道之间的差异或相关性)。相比之下,卷积神经网络(CNNs)在小规模训练数据上具有更好的收敛特性,但感受野有限。现有的工作致力于探索CNN和transformer的组合,而忽略了注意力的崩溃(即模型倾向于忽略或过度强调某些区域,而忽视了其他区域。),使transformer的潜力被低估了。在本文中,我们建议构建CNN风格的transformer(ConvFormer),以促进更好的注意力收敛,从而提高分割性能。具体而言,ConvFormer由池化、CNN式自注意(CSA)和卷积前馈网络(CFFN)组成,对应于最基本形式的transformer中的标记化、自注意和前馈网络。与位置嵌入和标记化相比,ConvFormer采用2D卷积和最大池化来保存位置信息和减少特征大小。通过这种方式,CSA将2D特征图作为输入,并通过构造自注意矩阵作为具有自适应大小的卷积核来建立长程依赖性。在CSA之后,通过CFFN利用2D卷积进行特征细化。在多个数据集上的实验结果证明了ConvFormer作为即插即用模块的有效性,可以持续提高基于transformer的框架的性能.
1 介绍
得益于对长程依赖进行建模的突出能力,transformer已成为自然语言处理的事实标准[1]。与鼓励局部性、权重共享和平移等变性的卷积神经网络(CNNs)相比,transformers通过自关注层建立全局依赖性,为特征提取带来了更多可能性,并反过来打破了CNNs的性能上限[2,3,4,5,6]。
受此启发,Transformer被引入医学图像分割,并引起了广泛关注[7,8,9,10,11]。在视觉转换器中,每个医学图像首先被分割成一系列补丁,然后被投影到补丁嵌入的1D序列[4]。通过在补丁/令牌之间建立成对的交互,transformer应该聚合全局信息以进行稳健的特征提取。然而,transformer中的学习良好收敛全局依赖性是高度数据密集型的,在医疗成像数据相对有限的情况下,使transformer的有效性降低。
为了弄清楚transformer在医学图像分割中是如何工作的,我们在ACDC数据集上训练了四个最先进的基于transformer的模型[5,12,13,14](Transunet、Transfuse、 F AT-Net),并可视化了不同层的学习自注意矩阵,如图所示。1。对于所有方法,注意力矩阵在补丁之间趋于一致(即注意力崩溃[15]),尤其是在更深的层中。注意力崩溃更为明显,尤其是在CNN Transformer混合方法(即TransUNet、TransFuse和F AT Net)中。(目前存在的问题,或者本文要解决的问题)一方面,训练数据不足会使transformer学习次优的长期依赖性。另一方面,直接将CNN与transformer相结合会使网络偏向于CNN的学习,因为与transformer相比,CNN更容易实现收敛,尤其是在小规模训练数据上。因此,如何解决注意力崩溃问题,提高transformer的收敛性,对提高transformer的性能至关重要。
在这项工作中,我们提出了一个名为ConvFormer的即插即用模块,通过构建一个内核可扩展的CNN风格transformer来解决注意力崩溃问题。在ConvFormer中,2D图像可以直接建立足够的长程依赖性,而无需分割成1D序列。具体而言,对应于最基本形式的transformer中的标记化、自注意和前馈网络,ConvFormer分别由池化、CNN式自注意(CSA)和卷积前馈网络(CFFN)组成。对于输入图像/特征图,首先通过交替应用卷积和最大池来降低其分辨率。
然后,CSA通过自适应地生成可缩放卷积来为每个像素建立适当的依赖性,该卷积较小以包括局部性,或者较大以用于长程全局交互。最后,CFFN通过应用连续卷积来细化每个像素的特征。在五种最先进的基于变压器的方法的三个数据集上进行的大量实验验证了ConvFormer的有效性,优于现有的注意力崩溃解决方案。
2 相关工作
最近用于医学图像分析的基于变换器的方法主要集中于引入变换器,用于编码器中的鲁棒特征提取、跳跃连接中的跨尺度特征交互以及解码器中的多种特征融合[16,17,18,19]。关于解决医学成像中变压器注意力崩溃的研究还处于探索阶段。即使在自然图像处理中,注意力崩溃(通常存在于基于深度变换的模型的深层)也没有得到充分的研究。具体而言,周等人[15]开发了重新注意力,以重新生成自注意力矩阵,旨在增加其在不同层上的多样性。周等人[20]将自注意矩阵投影到高维空间中,并应用卷积来促进自注意矩阵的局部性和多样性。Touvron等人[21]提出重新加权自注意模块和前馈模块输出的通道,以促进变压器的收敛。
3 方法
视觉转换器(ViT)和ConvFormer之间的比较如图所示。最大的区别是我们的ConvFormer是在2D输入上进行的,而ViT应用于1D序列。具体来说,池化模块(pooling)用于取代ViT中的标记化,它很好地保留了位置和位置信息,而没有额外的位置嵌入。CNN风格的自注意(CSA)模块,即ConvFormer的核心,是为了取代ViT中的自注意模块,通过以类似于具有自适应和可扩展内核(自适应内核就是具有自适应内核的卷积允许模型根据输入数据的特征动态生成或调整卷积核)(可扩展内核卷积是一种允许卷积核的大小和形状随着输入数据的不同而改变的卷积方式。)
的卷积的方式构建自注意矩阵来构建长程依赖性而开发的。卷积前馈网络(CFFN)被开发用于细化与ViT中的前馈网络(FFN)相对应的每个像素的特征。由于池化模块可以通过调整最大池化时间来匹配输出大小,因此不采用上采样程序将ConvFormer的输出调整回输入大小。需要注意的是,ConvFormer是基于卷积实现的,它消除了第1节中分析的CNN和transformer之间的训练张力。ConvFormer的每个模块如下所述。
图2。基本形式的视觉转换器和ConvFormer的比较。CBR是卷积、批处理规范化和Relu的组合的缩写。为了简单起见,省略了多个头。
3.1 池化 vs Tokenization
池化模块的开发是为了实现标记化的功能(即,使输入适合通道维度上的transformer和整形以及在需要时减小输入大小),同时在标记化中不会丢失网格线中的细节。对于输入
首先应用核大小为3×3的卷积,然后进行批量归一化和Relu来捕获局部特征(这一些列是CBR)。然后,然后,对应于ViT中的每个补丁大小S,总计d = log2 S ,在池化模块中应用下采样操作以产生相同的分辨率。这里,每个下采样操作由核大小为2×2的最大池和3×3卷积、批处理归一化和Relu的组合组成。最后 通过池模块, Cm对应于ViT中的嵌入维度。
3.2 CNN式与顺序式自我关注
在ConvFormer中,远程依赖关系的建立依赖于cnn风格的自关注,它通过构造自定义的卷积核来为每个像素创建一个自适应的接受场。具体来说,对于每个像素,卷积核基于两个中间变量构造:
将3×3邻域中相邻像素的特征合并到的VIT中可学习的投影矩阵Q,K,V是否与Cp的嵌入维数相对应,然后是初始的自定义卷积核对于是通过计算余弦相似度来计算的。
并且很少情况下对应于ViT中的注意得分计算(约束为正)要么是积极的,要么是消极的)。然后,动态确定自定义的大小的卷积核,对于通过引入一个可学习的高斯距离图M。
这里,控制A和a接受野的可学习网络参数是控制接受野倾向的超参数,与接受野成正比。例如,在典型的环境下接受野只覆盖五个相邻的像素,接受野是全局性的。更大的是更有可能的A倾向于有一个全球性的接受领域。基于,计算由。这样,每个像素都有一个定制的可缩放卷积核,푗。通过将A与V相乘,CSA可以构建自适应的远程依赖关系,其中푉可以根据Eq.(1)类似地表示。最后,利用1 × 1卷积、批处理归一化和Relu的组合来整合从远程依赖关系中学习到的特征。
3.3 卷积与基本前馈网络
卷积前馈网络(convolution feedforward network, CFFN)是对CSA产生的特征进行细化,仅由1 × 1卷积、批处理归一化和Relu两种组合组成。通过替换ViT中的线性投影和层归一化,CFFN使ConvFormer完全基于CNN,避免了CNN-Transformer混合方法在训练过程中CNN和Transformer之间的争斗。
4 实验
4.1 数据集和实现细节
ACDC 自动心脏诊断挑战的公开可用数据集。左心室(lv)、心肌(MYO)和右心室(R V)的逐像素注释扫描共100次[22]。以下[12,17,18]分别使用70例、10例和20例进行训练、验证和测试。
在公开可用的ACDC(即fat - net[14])上,方法优于最先进的2D方法:在A vg中为91.46%。DSC)和ISIC(即Ms Red[28]: 90.25%。DSC)数据集。更全面的定量比较结果见补充资料。
ISIC 2018
一个公开可用的皮肤病变分割数据集。共有2594张带有像素级注释的皮肤镜病变图像[23,24]。
根据[25,26],将数据集随机分为2076张用于训练的图像和520张用于测试的图像。
ICH
局部采集的血肿分割数据集。三名放射科医师收集了99张CT扫描片,共2648片,并对其进行了注释。数据集按照7:1:2的比例随机分为训练集、验证集和测试集。
实现细节
为了公平的比较,在相同的设置下,使用或不使用ConvFormer对所有选定的基于核心变压器的基线进行了训练。所有模型都由亚当优化器进行训练比率为0.0001,批量大小为400轮4。数据增强包括随机旋转、缩放、对比度增强和伽马增强。
4.2 结果
ConvFormer可以作为即插即用模块工作,并取代基于变压器的基线中的普通变压器块。为了评估ConvFormer的有效性,我们选择了五种最先进的基于变压器的方法作为骨干,包括SETR[5]、TransUNet[12]、TransFuse[13]、F AT-Net[14]和Patcher[27]。SETR和Patcher采用纯变压器编码器,TransUNet、TransFuse和F - net采用CNN-Transformer混合编码器。此外,Re-attention[15]、LayerScale[21]和Refiner[20]这三种最先进的解决注意力崩溃的方法都配备了上述基于变压器的基线进行比较。
定量的结果。
表1总结了三个数据集上嵌入各种基于变压器的基线的ConvFormer的定量结果。ConvFormer在所有五个主干上实现一致的性能改进。与CNN-Transformer混合方法(即TransUNet, TransFuse和F - net)相比,ConvFormer更有利于纯变压器方法(即SETR和Patcher)。其中,SETR在ACDC、ISIC和ICH数据集上对Dice的平均提升分别为3.86%、1.38%和1.39%,Patcher的相应性能提升分别为0.66%、1.07%和1.15%。相比之下,在CNN-Transformer混合方法中,如前所述,cnn在训练时对transformer更具优势。尽管如此,通过ConvFormer重新平衡cnn和transformer可以建立更好的长期依赖关系,从而实现一致的性能改进。
与SOTA方法的比较
与最先进的解决注意力崩溃的方法相比,定量结果总结在表1中。一般来说,由于训练数据相对有限,现有的自然图像处理方法不适合医学图像分割,导致在不同主干和数据集上的性能不稳定。相比之下,ConvFormer始终优于这些方法,并为跨数据集的各种主干网带来稳定的性能改进,证明了ConvFormer作为即插即用模块的出色通用性。
自我注意矩阵的可视化
为了定性地评估ConvFormer在解决注意力崩溃和建立有效的远程依赖方面的有效性,我们将有和没有ConvFormer的自注意矩阵可视化,如图3所示。通过引入conformer,有效地缓解了注意力崩溃。与基线的自注意矩阵相比,ConvFormer学习到的矩阵更加多样化。具体来说,每个像素的交互范围是可扩展的,局部保留较小,全局接受域较大。此外,依赖性不再像ViT那样被约束为正,更符合卷积核。
图3所示。基于基线w/和w/o ConvFormer的自注意矩阵可视化。
消融研究 如3.2节所述,a用于控制ConvFormer中的感受场倾向,a越大,说明ConvFormer包含的感受场越大。为了验证这一点,我们对훼进行了消融研究,如表2所示。一般来说,使用较大的a不一定会带来更多的性能改进,这与我们的观察一致,即不是每个像素都需要全局信息进行分割。
5 结论
在本文中,我们将变压器构建为一个核可扩展的卷积来解决注意力崩溃问题,并为有效的医学图像分割构建不同的远程依赖关系。具体来说,它由池化、cnn风格的自关注(CSA)和卷积前馈网络(CFFN)组成。池化模块首先用于提取局部性细节,同时通过对输入进行下采样来减少后续CSA模块的计算成本。
然后,通过将CSA构造为核可扩展的卷积,将CSA发展为自适应的远程依赖关系,最后,使用CFFN对CSA进行改进每个像素的特征。在3个数据集的5个最先进的基线上的实验结果表明,ConvFormer的性能突出,稳定地超过了3个数据集的基线和比较方法。