目录
Title:Multi-Compound Transformer for Accurate Biomedical Image Segmentation
Transformer-self-Attention(TSA)
Transformer-Cross-Attention(TCA)
Title:Multi-Compound Transformer for Accurate Biomedical Image Segmentation
摘要-Abstract
在CV领域中,Transformer被用于学习不同token之间的局部交互作用,但是之前的一些研究遗漏了对不同像素之间跨尺度的依赖性,不同标签之间对应的语义关系,特征表征与对应语义嵌入的一致性。这些对精确的医学图像分割十分重要。因此本文提出了一种多类别复合Transformer来解决上述问题,本文的网络框架将丰富的特征学习和语义结构挖掘集成。
具体来说,MCTrans将多尺度的卷积特征表示为token序列,执行同尺度或者跨尺度的注意力计算,而不像之前计算的都是单尺度的注意力。此外还在建模语义关系是引入一个可训练的代理嵌入器,通过自注意力(TSA)来建模语义关系,通过交叉注意力(TCA)来进行特征增强。
Introduction
解决的问题
1)引入了TSA模块,通过自我注意机制实现跨尺度的像素级上下文建模,从而对不同尺度进行更全面的特征增强。
2)开发了TCA模块,通过引入代理嵌入来自动学习不同语义类别的语义对应关系。使用这种代理嵌入,通过交叉注意机制与特征表征进行交互。通过对更新后的代理嵌入添加辅助损失可以有效提高同一类别的特征相关性和不同类别之间的特征可辨别性。
论文的主要贡献
(1)提出的MCTrans可以有效的建立不同尺度之间的依赖关系和特征之间的关联从而进行精确的医学图像分割;
(2)提出了一种新的可训练的proxy embedding,通过SA和CA来建立类别之间的依赖性和增强特征表述;
(3)本文提出的MCTrans可以方便的集合进UNet网络类型,并且在6个具有挑战的数据集上均达到了SOTA。实验结果证明了本文提出的网络结构的有效性。
方法-Method
如下图所示,MCTrans的结构图如下所示:
就是在经典的UNet的编码器解码器架构之间引入了MCTransformer,它由TSA和TCA模块构成。引入前者来对多个特征之间的上下文信息进行编码,产生丰富而一致的像素级上下文。而后者引入了可学习嵌入的语义关系建模,并进一步增强了特征表示。
方法概括
实际上给定一个图像l,采用深度CNN来提取不同尺度的多层次特征{},对于层级i,特征以P*P的Patch大小展开,其中P设置为1,即第i个特征图每个位置被视为patch,总共有个patch。接下来将不同层的分割patch传递到具有相同输出特征维度的单个投影头(即1*1卷积层),并获得嵌入的Token,并且为了补偿缺失的位置信息,位置嵌入被补充到token当中,以提供关于特征在序列中相对或绝对位置信息,用公式可以表述为T=T+。接下来,我们将Token输入TSA模块,用于多尺度上下文建模。输出增强的Token进一步通过TCA模块,并与代理嵌入进行交互,代理嵌入。M是数据集的类别数。最后,我们将编码的token折叠回金字塔特征,并以自下而上的方式合并,以获得用于预测的最终特征图。
Transformer-self-Attention(TSA)
TSA的结构如图
这个模块的输入是1D Token,主要采集不同尺度特征之间的上下文依赖关系,在其具体结构中我们可以看到包含MSA层+AddNorm层+FFN层以及Add Norm层,并使用了残差连接,FFN中包含两次线性层和Relu做非线性激活,因此对于第l层自我注意的输入是从输入计算的三元组,如:
三个W是不同线性投影头的参数矩阵,第二个式子是自注意力机制的公式表示
MSA是一个具有h个独立的SA操作的扩展,并将他们的连接输出投影为式子(3),是输出线性投影头的一个参数。整个计算过程可以表示为多头注意力机制部分和前馈神经网络部分的二者求和。并且为了简单起见式子中忽略了LN
Transformer-Cross-Attention(TCA)
这个模块的作用是用来增强特征表述的,其结构如下图所示:
在这一模块当中,除了增强的Token ,还提出了一种可学习的代理嵌入来学习类别之间的全局语义关系(即类内/类间)。与TSA一样,结构中包含注意力层,LayerNorm层,AddNorm层,但是在网络结构中使用了两个多头注意力层。对于第j层,代理嵌入由各种线性投影头进行转换生成第一个MSA块的输入(q,k,v)。在这里,MSA块的自我注意机制与每一对类别进行连接和交互,从而建模不同标签的语义对应关系。接下来,学习到的代理嵌入通过在另一个MSA块中的交叉注意,提取并与输入标记的特征进行交互。其中查询输入是代理嵌入,键,值输入是Token。
通过交叉注意,Token的特征与学习到的全局语义关系进行融合,全面的提高了类内一致性和特征表示的类间可辨别性,产生了最新的代理嵌入。两个MSA块的计算等于
此外我们还引入了辅助损失函数来促进代理嵌入学习,特别的输出,对TCA模块最后一层的证明进一步传递给线性投影头,并产生多类预测,在GT分割掩模的基础上,我们找到了唯一的元素来计算分类标签的监督。这样,驱动代理嵌入学习适当的语义关系,有助于提高同一类别的特征相关性和不同类别之间的特征可辨别性。最后将编码的标记折叠回二维特并附加不涉及的特征来形成金字塔特征{X0,X1,X'2,X'3,X'4}.然后我们逐渐对其进行上采样,以获得最终的分割结果图。
总结
本文提出了一种强大的基于Transformer的医学图像分割网络,本文的方法通过强大的注意机制结合了丰富的上下文建模和语义关系挖掘,有效的解决了跨尺度依赖性,不同类别的语义对应关系等问题。