代码:BATFormer
发表时间:2023
发表期刊:IEEE TMI
这篇论文介绍了一种新的Transformer架构,称为BATFormer(Boundary-Aware Lightweight Transformer),旨在改进医学图像分割的效率和效果。
目标
- 解决现有问题:传统的卷积神经网络(CNNs)由于感受野有限,在处理需要长距离依赖关系的任务时表现不佳。而现有的Transformer虽然能解决这一问题,但其计算复杂度高,并且在进行全局表示学习时采用了固定的窗口分区策略,这在医学图像分割任务中可能导致边界失真。
创新点
- 跨尺度全局Transformer (CGT) 模块:为了利用Transformer在建立长距离依赖关系上的优势,同时减少计算复杂度,作者提出了一个CGT模块。该模块可以联合使用多个小规模特征图来提取更丰富的全局特征,从而降低计算成本。
- 边界感知局部Transformer (BLT) 模块:考虑到形状建模在医学图像分割中的重要性,设计了BLT模块。与传统Transformer采用固定窗口分区不同的是,BLT采用基于熵的自适应窗口分区方案,以减少计算复杂度的同时保持物体的形状。
方法
A. 概述
BATFormer的整体架构如图2所示,包括一个轻量级U形主干网络和两个主要组件,旨在解决前述的问题:
- 跨尺度全局Transformer (CGT)
- 边界感知局部Transformer (BLT)
这两个组件分别针对医学图像分割中遇到的计算复杂度高和边界失真问题进行了优化。
1. 跨尺度全局Transformer (CGT)
-
计算成本与输入序列长度的关系:在大规模表示(如F1和F2)上构建全局Transformer的成本非常高,因为计算复杂度与输入序列长度的平方成正比。
-
深度语义特征的需求:全局依赖关系主要用于构建深层语义特征。因此,CGT仅接收多个小规模(即高层次)特征图(即F3、F4和F5),通过跨尺度自注意力机制,CGT能够在较低计算复杂度的情况下提取更丰富的语义特征,用于全局表示学习。这反过来补充了CNN模块,实现了联合的局部和全局特征提取与融合。
2. 边界感知局部Transformer (BLT)
- 灵活窗口分区方案:为了解决边界失真的问题,提出了基于熵计算的灵活窗口分区方案来定位与边界相关的窗口。通过这种方法,BLT可以过滤掉非边界窗口,更专注于探索长距离特征以进行边界检测,并且降低了计算复杂度。
B.轻量级U形主干网络
BATFormer的轻量级U形主干网络是整个架构的基础,它主要由四个下采样块、四个上采样块和五个跳跃连接组成。这种设计借鉴了经典的U-Net架构,但进行了优化以减轻计算负担,同时保持高效的特征提取能力。
主干网络结构
-
输入处理:
- 每个输入图像首先经过两次3x3卷积操作。
-
下采样阶段(Encoder):
- 经过初始卷积后,图像被送入四个连续的下采样块中。
- 每个下采样块包含一个2x2的最大池化层(步长为2),用于减少空间维度,以及两个3x3的卷积层来提取特征。
- 下采样的通道数依次设置为 2 C , 4 C , 8 C , 8 C 2C, 4C, 8C, 8C 2C,4C,8C,8C,其中 C = 16 C=16 C=16。
-
跳跃连接:
- 相同尺度的下采样块和上采样块通过跳跃连接相连,这有助于保留和传递局部信息,防止在下采样过程中丢失重要细节。
-
上采样阶段(Decoder):
- 四个上采样块负责将特征图逐步恢复到原始尺寸。
- 每个上采样块包括双线性插值(用于放大图像)、特征图拼接(与对应的下采样块通过跳跃连接传来的特征图结合),以及两个3x3的卷积层。
- 上采样的通道数依次设置为 4 C , 2 C , C , C 4C, 2C, C, C 4C,2C,C,C,同样 C = 16 C=16 C=16。
C. 跨尺度全局Transformer (CGT)
CGT旨在通过两个跨尺度注意力模块和一个前馈网络(FFN)来构建三个小尺度特征图之间的全局依赖关系。这三个特征图由主干网络生成,分别是 F 3 ∈ R 4 C × H 4 × W 4 F3 \in \mathbb{R}^{4C \times \frac{H}{4} \times \frac{W}{4}} F3∈R4C×4H×4W、 F 4 ∈ R 8 C × H 8 × W 8 F4 \in \mathbb{R}^{8C \times \frac{H}{8} \times \frac{W}{8}} F4∈R8C×8H×8W 和 F 5 ∈ R 8 C × H 16 × W 16 F5 \in \mathbb{R}^{8C \times \frac{H}{16} \times \frac{W}{16}} F5∈R8C×16H×16W,其中 C C C 是第一个分支头的通道数, ( H , W ) (H, W) (H,W) 是输入图像的分辨率。
跨尺度注意力模块
在跨尺度注意力模块中:
-
查询生成:最高分辨率的特征图 F 3 F3 F3 被投影成两个查询 Q 3 , 4 ∈ R H W 16 × d Q3,4 \in \mathbb{R}^{\frac{HW}{16} \times d} Q3,4∈R16HW×d和 Q 3 , 5 ∈ R H W 16 × d Q3,5 \in \mathbb{R}^{\frac{HW}{16} \times d} Q3,5∈R16HW×d,其中 d d d 是Transformer模块的维度。
-
键值对生成:较低分辨率的特征图 F 4 F4 F4 和 F 5 F5 F5 分别被投影成键值对 { K 4 , V 4 } ∈ R H W 64 × d \{K4, V4\} \in \mathbb{R}^{\frac{HW}{64} \times d} { K4,V4}∈R64HW×d和 { K 5 , V 5 } ∈ R H W 256 × d \{K5, V5\} \in \mathbb{R}^{\frac{HW}{256} \times d} { K5,V5}∈R256HW×d。
-
跨尺度注意力计算:跨尺度注意力公式如下:
F i c a ( Q 3 , i , K i , V i ) = softmax ( Q 3 , i K i T d ) V i F_i^{ca}(Q_{3,i}, K_i, V_i) = \text{softmax}\left(\frac{Q_{3,i}K_i^T}{\sqrt{d}}\right)V_i Fica(Q3,i,Ki,Vi)=softmax(dQ3,iKiT)Vi
其中 i i i分别为4和5,对应于 F 4 F4 F4 和 F 5 F5 F5。
与标准自注意力机制不同的是,这里的键 K K K 和值 V V V 来自其他两个较小尺度的特征图,而不是来自同一输入。这样做有两个主要优点:
-
降低计算复杂度:由于 K K K和 V V V 的序列长度较短,计算复杂度可以减少 2 2 2^2 22 或 2 4 2^4