Medical Transformer:门控轴向注意力用于医学图像分割
Abstract
过去十年间,深度卷积神经网络广泛应用于医学图像分割领域并取得了优异的性能;但是鉴于卷积结构自身的局限性,无法有效的捕获长程依赖关系。
近期提出的Transformer借助self-attention机制可以有效捕获长程依赖关系,获得更丰富的特征表述。因此也启发我们探究基于Transformer的网络结构用于医学图像分割。目前大多数研究需要Transformer在大规模数据集上进行预训练,但是医学图像与自然图像相比,数据集规模相对较小,因此使得Transformer应用于医学图像分割存在挑战。
为此本文提出基于门控的轴向注意力模型,为SA模块引入了一个额外的控制机制。此外为了有效的训练医学图像,本文还提出了一种叫做LoGo的训练策略,通过整张图和patch分别学习全局和局部特征。本文提出的MedT模型在三种不同种类的医学图像分割数据集上进行了测试,均取得了优于卷积神经网络的分割结果。
Section I Introduction
开发自动、准确、鲁邦的医学图像分割方法一直是医学成像的主要任务之一,因为这对计算机辅助诊断等至关重要,从医学图像中分割器官、组织可以辅助医生进行临床诊断并制定治疗计划。目前主流方法是基于卷积神经网络进行搭建,如UNet,V-Net,3D-UNet,Res-UNet,Dense-UNet,Y-Net,UNet+++,KiU-Net,这些网络在诸多具有挑战性的任务上取得了优异的性能,验证了ConvNet在特征提取进行器官/组织分割的有效性。
ConvNets是目前分割方法的主流,但是它们缺乏建模长程依赖的能力;具体地说是每一次卷积操作其感受野只是图像的局部像素,而不是全局信息;也有的研究通过图像金字塔、空洞卷积、注意力机制等进行长程建模。但是在如何建模长程依赖关系这方面仍然有继续改进的空间,因为前任的工作并不是完全聚焦于医学图像分割领域。
为了更好的理解为什么长程依赖关系对医学图像很重要,本文可视化了一个新生儿的超声扫描图像,详情参见Fig 1.为了进行有效的分割,网络需要识别哪些像素是背景,哪些属于mask哪些属于前景物理。鉴于图像的背景十分发散,因此需要学习背景像素之间的远程依赖关系,这样才能减少假阳性的判断(假阳就是将背景误认为前景);而当前景mask比较大时,则需要学习mask像素之间的远程依赖关系,才能做出正确的预测。与GT(Fig 1(e)相比)Fig 1(b)©中都错误的将背景分割成了brain区域,而基于Transformer的MedT(d)没有发生误分类,这就突出了长程依赖关系的重要性。
在诸多NLP任务中,Transformer已经被证实可以有效编码长程依赖关系,主要是基于self-attention机制可以有效的计算输入序列之间的关联。近期Transformer也被用于CV领域,如Axial-DeepLab使用一个轴向注意力模块,将二维SA分解为两个1D的SA计算,并且引入对位置敏感的轴向注意机制进行分割。SETR则是使用Transformer作为编码器,ConvNets作为解码器,以此组建了一个强大的分割器。而在医学图像分割中Transformer还没有做太多的探索,近期的工作主要聚焦于如何基于注意力机制来提升性能,但是这些网络仍然以CNN作为编码和解码器的主要模块。
本文观察到,Transformer需要在大规模数据集上预训练才能得到较好的性能,但是医学图像其数据集规模并不如自然图像规模那么庞大,能用于预训练的医学图像资源十分有限;而医学图像打标签也十分费时费力,需要很强的专家只是。而没有充足的专家信息使得学习十分困难,无法有效的学习位置信息、对图像进行编码等。
因此本文提出一种轴向的且对位置敏感的注意力机制,引入4个门来控制图像信息和位置信息的流动;而这些门控参数都是可以学习的,使得MedT可以用于任何规模的数据集。
根据数据集的大小,这些门会决定学习特定信息需要多少训练数据;此外本文还提出了一种局部-全局的训练策略(LoGo),即分别使用一个较浅的全局分支和一个较深的局部分支分别提取图像信息,这种方法使得我们不仅可以关注局部细节信息也可以关注全局信息,本文将这种基于轴向注意力的对位置敏感的Transformer模型称之为MedT。
本文工作总结如下:
(1)提出一种基于门控机制的对位置敏感的轴向注意力计算方法,即使在小规模数据集上也有较好的性能;
(2)引入LoGo这种有效的Transformer训练策略;
(3)基于上述两点提出的MedT框架专门用于医学图像分割任务;
(4)在三种不同数据集上的分割结果显示MedT性能显著优于传统的卷积网络或基于注意力机制的网络。
Section II Medical Transformer(MedT)
Part 1 Self-Attention Overview
对于CinxHxW的输入图像,self-attention的计算为:
其中Wq,Wk,Wv是需要学习的参数。但是原始SA的计算会随着输入图像的尺寸不断增大,因此不适合用于计算机视觉任务中处理图像;此外与卷积不同,SA的计算无法利用任何位置信息,而位置信息在计算机视觉任务中对于感知对象结构也十分有用。
Axial-Attention
为了降低注意力的计算复杂度,本文将SA的计算分解为两个SA模块,第一个模块沿feature map高度轴进行计算,第二个模块沿着feature map宽度轴进行计算。同时也通过position bias嵌入位置信息,通常进行的是相对位置嵌入,这些位置编码通常可以通过训练来学习,已经有研究证明可以编码图像的空间结构。
而本文参考Axial DeepLab的做法,对所有的q,k,v均进行位置编码。因此对于给定的输入,SA的计算公式为:
这只是描述了沿宽度轴的SA计算,沿高度轴的计算也类似。
Part 2 Gated Axial-Attention
本文探讨了使用轴向注意力的好处,可以有效的捕获非局部的上下文信息并且计算更高效,同时可以有效的嵌入位置信息,捕获特征映射之间的长程依赖关系。但是这些模型都是在大型分割数据集上训练的结果,可以同分的学习q,k,v的各种差异。本文则关注于小规模数据集上的应用,因为这是医学图像分割的常见应用场景。
在这种情景下,往往很难学习位置信息,从而也就无法精确的编码长程依赖关系。此时在相对位置信息不够精确的前提下,将位置信息嵌入到各自的q,k,v中会导致精度下降,因此本文提出一种改进的轴向注意力模块(modified axial-attention block)可以控制位置偏差的嵌入位置、施加的影响。
比如添加门控机制后,宽度轴上的SA计算公式变为:
Gq,Gk,Gv1,Gv2是门控参数,都是科学系的,通过这些门控参数控制相对位置信息对非局部上下文信息的影响;通常当相对位置信息比较准确时门控参数会赋予更高的权重。
Part 3 Local-Global Training
显然,如果输入的是patch,transformer计算会很快,但仅训练patch不足以完成医学图像的分割任务,因为patch限制了学习patch之间的像素和语义关联。
为了更好的理解整张图,本文在网络中使用了两个分支,一个全局分支和一个局部分支,全局分支处理的是整张图像,局部分支处理的是图像patch。
全局分支中减少了门控Transformer层的数量,因为通过实验观察到transformer仅依靠前面几层就足以建模长程依赖关系;本地分支则使用的是I/4大小的16个patch作为输入,经过transformer层处理的输出会重采样获得最终的输出;这两个分支的处理结果会进行add然后经过1x1卷积层 获得最终的分割结果。
这种训练策略有效的捕获了高级全局信息和浅层次细节信息。
因此MedT使用门控轴向注意力层作为基础模块,并采用LoGo策略进行训练,详情参见Fig 2(a)。
Fig 2(a)展示了MedT的整体架构,Fig 2(b)展示的是MedT中的轴向transformer层,Fig 2©则是门控轴向注意力层,是MedT中的基础组件.
Section III Exoeriments
Dataset
Brain anatomy Segmentation
Gland数据集
MoNuSeg数据集
损失函数:交叉熵损失函数
对比网络:
FCN,UNet,UNet++,Res-UNet
本文还进行了消融实验测试各部分的作用
Results
Table 1展示了不同数据集上不同网络的分割结果,评价指标有F1分数和IoU,对比网络包含CNNbased的FCN,UNet,UNet++,Res-UNet;注意力机制的Axial Attention UNet;以及本文的MedT
对于脑部分割,可以看到Attention网路性能比CNN网络性能好;但是对于胰腺和细胞分割因为数据集规模较小,attention机制就有一定局限性。而本文的门控机制可以减轻这一限制,结合LoGo训练策略取得了SOTA结果,在三类数据集上的性能都达到了SOTA。
Fig 3可视化了三类数据集上一些实际的分割结果,可以看到MedT确实有效的捕获了长程依赖关系,比如第二行的白点,所有卷积网络都没能分割出来;而由于注意力机制编码了长程依赖关系就能够分割出这一部分;而在第一行和第四行CNN和注意力机制均出现了误分割,而MedT的门控机制可以很好的根据像素之间的依赖关系决定注意力的权重,从而获得更精确的分割结果。
Section IV Conclusion
本文探究了基于Tansformer的医学图像分割模型,并提出了基于轴向注意力的MedT分割框架,它可以通过门控机制决定注意力的权重;此外本文还提出了一种LoGo训练策略可以按照不同分支充分学习全局和局部信息。
MedT的优势在于无需在大规模数据集上进行预训练,本文在三类分割数据集上进行了广泛的实验,均达到了SOTA。
Appendix
Part 1 Dataset
Brain US Dataset
颅内出血是导致早产儿脑损伤的主要病因,主要诊断方法需要进行露骨成像。
Brain US Dataset采集了20名早产儿的1629张颅内扫描图像,training:testing = 1629:329 图像大小reshape到128128
GLAS Dataset
Gland数据集包含165张胰腺HE染色图像,training:testing = 85:80.可以看到数据集规模较小 依旧reshape到128128
MoNuSeg Dataset
包含40x放大的细胞HE染色图像,training:testing = 30:14 图像大小reshape到512*512
Part 2 MedT
MedT使用LoGo训练策略,包含两个分支:全局分支和局部分支。而这两个分支的输入是卷积模块输出的feature map,卷积块包含3层conv,每次卷积之后经过BN+ReLU。 两个分支的Encoder使用的就是transformer layer,Encoder使用两个transformer block分别沿高度轴和宽度轴计算SA,而每个MHSA模块则如Fig2©所示包含8个门控参数;MHSA的输出经过1x1卷积处理后送入Decoder;两个分支均包含5个encoder block和5个decoder block。
Part 3 Training Details
优化器 Adam LR = 0.001 训练400epochs 训练卡 Nvidia quadro8000
Part 4 Ablation Study
首先搭建UNet,随后添加残差连接组成Res-UNet。
Axial UNet则是将所有卷积层替换为axial attention层;Gated Axial UNet是将轴向注意替换为带门控的轴向注意层。
训练时也对比了仅使用全局分支或局部分支的效果。
Table 1是消融实验的对比结果。
Table 2展示的是不同网络的参数量及对应精度。其中mod都是轻量级版本,通过对比主要是为了说明即使有的基准网络具备更多网络参数,性能也没有超过MedT。
Part 5 Results
Fig 1可视化了一些分割结果,可以看到基于CNN的方法总是会有误分割现象,但是本文的MedT就不会产生误分割。
Part 6 Concurrent works
TransUNet灵感源自ViT,在UNet结构中融合了Transformer,但是TransUNet依旧依赖于大规模数据集预训练的结果;
Transfuse也是近期提出用于息肉分割的网络框架,使用了并行CNN分支与Transformer分支融合进行分割。
而本文则是探究仅借助Transformer的注意力机制进行分割的可行性,同时无需任何预训练。