代码:BATFormer
发表时间:2023
发表期刊:IEEE TMI
这篇论文介绍了一种新的Transformer架构,称为BATFormer(Boundary-Aware Lightweight Transformer),旨在改进医学图像分割的效率和效果。
目标
- 解决现有问题:传统的卷积神经网络(CNNs)由于感受野有限,在处理需要长距离依赖关系的任务时表现不佳。而现有的Transformer虽然能解决这一问题,但其计算复杂度高,并且在进行全局表示学习时采用了固定的窗口分区策略,这在医学图像分割任务中可能导致边界失真。
创新点
- 跨尺度全局Transformer (CGT) 模块:为了利用Transformer在建立长距离依赖关系上的优势,同时减少计算复杂度,作者提出了一个CGT模块。该模块可以联合使用多个小规模特征图来提取更丰富的全局特征,从而降低计算成本。
- 边界感知局部Transformer (BLT) 模块:考虑到形状建模在医学图像分割中的重要性,设计了BLT模块。与传统Transformer采用固定窗口分区不同的是,BLT采用基于熵的自适应窗口分区方案,以减少计算复杂度的同时保持物体的形状。
方法
A. 概述
BATFormer的整体架构如图2所示,包括一个轻量级U形主干网络和两个主要组件,旨在解决前述的问题:
- 跨尺度全局Transformer (CGT)
- 边界感知局部Transformer (BLT)
这两个组件分别针对医学图像分割中遇到的计算复杂度高和边界失真问题进行了优化。
1. 跨尺度全局Transformer (CGT)
-
计算成本与输入序列长度的关系:在大规模表示(如F1和F2)上构建全局Transformer的成本非常高,因为计算复杂度与输入序列长度的平方成正比。
-
深度语义特征的需求:全局依赖关系主要用于构建深层语义特征。因此,CGT仅接收多个小规模(即高层次)特征图(即F3、F4和F5),通过跨尺度自注意力机制,CGT能够在较低计算复杂度的情况下提取更丰富的语义特征,用于全局表示学习。这反过来补充了CNN模块,实现了联合的局部和全局特征提取与融合。
2. 边界感知局部Transformer (BLT)
- 灵活窗口分区方案:为了解决边界失真的问题,提出了基于熵计算的灵活窗口分区方案来定位与边界相关的窗口。通过这种方法,BLT可以过滤掉非边界窗口,更专注于探索长距离特征以进行边界检测,并且降低了计算复杂度。
B.轻量级U形主干网络
BATFormer的轻量级U形主干网络是整个架构的基础,它主要由四个下采样块、四个上采样块和五个跳跃连接组成。这种设计借鉴了经典的U-Net架构,但进行了优化以减轻计算负担,同时保持高效的特征提取能力。
主干网络结构
-
输入处理:
- 每个输入图像首先经过两次3x3卷积操作。
-
下采样阶段(Encoder):
- 经过初始卷积后,图像被送入四个连续的下采样块中。
- 每个下采样块包含一个2x2的最大池化层(步长为2),用于减少空间维度,以及两个3x3的卷积层来提取特征。
- 下采样的通道数依次设置为 2 C , 4 C , 8 C , 8 C 2C, 4C, 8C, 8C 2C,4C,8C,8C,其中 C = 16 C=16 C=16。
-
跳跃连接:
- 相同尺度的下采样块和上采样块通过跳跃连接相连,这有助于保留和传递局部信息,防止在下采样过程中丢失重要细节。
-
上采样阶段(Decoder):
- 四个上采样块负责将特征图逐步恢复到原始尺寸。
- 每个上采样块包括双线性插值(用于放大图像)、特征图拼接(与对应的下采样块通过跳跃连接传来的特征图结合),以及两个3x3的卷积层。
- 上采样的通道数依次设置为 4 C , 2 C , C , C 4C, 2C, C, C 4C,2C,C,C,同样 C = 16 C=16 C=16。
C. 跨尺度全局Transformer (CGT)
CGT旨在通过两个跨尺度注意力模块和一个前馈网络(FFN)来构建三个小尺度特征图之间的全局依赖关系。这三个特征图由主干网络生成,分别是 F 3 ∈ R 4 C × H 4 × W 4 F3 \in \mathbb{R}^{4C \times \frac{H}{4} \times \frac{W}{4}} F3∈R4C×4H×4W、 F 4 ∈ R 8 C × H 8 × W 8 F4 \in \mathbb{R}^{8C \times \frac{H}{8} \times \frac{W}{8}} F4∈R8C×8H×8W 和 F 5 ∈ R 8 C × H 16 × W 16 F5 \in \mathbb{R}^{8C \times \frac{H}{16} \times \frac{W}{16}} F5∈R8C×16H×16W,其中 C C C 是第一个分支头的通道数, ( H , W ) (H, W) (H,W) 是输入图像的分辨率。
跨尺度注意力模块
在跨尺度注意力模块中:
-
查询生成:最高分辨率的特征图 F 3 F3 F3 被投影成两个查询 Q 3 , 4 ∈ R H W 16 × d Q3,4 \in \mathbb{R}^{\frac{HW}{16} \times d} Q3,4∈R16HW×d和 Q 3 , 5 ∈ R H W 16 × d Q3,5 \in \mathbb{R}^{\frac{HW}{16} \times d} Q3,5∈R16HW×d,其中 d d d 是Transformer模块的维度。
-
键值对生成:较低分辨率的特征图 F 4 F4 F4 和 F 5 F5 F5 分别被投影成键值对 { K 4 , V 4 } ∈ R H W 64 × d \{K4, V4\} \in \mathbb{R}^{\frac{HW}{64} \times d} {K4,V4}∈R64HW×d和 { K 5 , V 5 } ∈ R H W 256 × d \{K5, V5\} \in \mathbb{R}^{\frac{HW}{256} \times d} {K5,V5}∈R256HW×d。
-
跨尺度注意力计算:跨尺度注意力公式如下:
F i c a ( Q 3 , i , K i , V i ) = softmax ( Q 3 , i K i T d ) V i F_i^{ca}(Q_{3,i}, K_i, V_i) = \text{softmax}\left(\frac{Q_{3,i}K_i^T}{\sqrt{d}}\right)V_i Fica(Q3,i,Ki,Vi)=softmax(dQ3,iKiT)Vi
其中 i i i分别为4和5,对应于 F 4 F4 F4 和 F 5 F5 F5。
与标准自注意力机制不同的是,这里的键 K K K 和值 V V V 来自其他两个较小尺度的特征图,而不是来自同一输入。这样做有两个主要优点:
-
降低计算复杂度:由于 K K K和 V V V 的序列长度较短,计算复杂度可以减少 2 2 2^2 22 或 2 4 2^4 24 倍。
-
引入多样化特征:不同的 K K K 和 V V V 组对应于不同尺度的感受野和语义信息,从而增强了依赖关系的多样性。
特征融合与前馈网络(FFN)
在送入FFN之前,两组跨尺度注意力的结果会被结合并精炼:
F c a = ( F 4 c a , 1 ⊕ ⋯ ⊕ F 4 c a , g ⊕ F 5 c a , 1 ⋯ ⊕ F 5 c a , g ) ⋅ W c a + F 3 , F^{ca} = (F_4^{ca,1} \oplus \cdots \oplus F_4^{ca,g} \oplus F_5^{ca,1} \cdots \oplus F_5^{ca,g}) \cdot W^{ca} + F3, Fca=(F4ca,1⊕⋯⊕F4ca,g⊕F5ca,1⋯⊕F5ca,g)⋅Wca+F3,
其中 g g g是CGT中预定义的Transformer头的数量, W c a ∈ R 2 g d × d W^{ca} \in \mathbb{R}^{2gd \times d} Wca∈R2gd×d是用于组合的可学习投影矩阵, ⊕ \oplus ⊕ 表示拼接操作, + + + 表示通过元素级加法实现的残差连接。
经过跨尺度注意力后的FFN进一步处理以获得CGT的最终输出:
F C G T = ( max ( 0 , F c a W c 1 + b c 1 ) ⋅ W c 2 + b c 2 ) + F c a , F^{CGT} = (\max(0, F^{ca}W_c1 + b_c1) \cdot W_c2 + b_c2) + F^{ca}, FCGT=(max(0,FcaWc1+bc1)⋅Wc2+bc2)+Fca,
其中 W c 1 ∈ R d × 4 d W_c1 \in \mathbb{R}^{d \times 4d} Wc1∈Rd×4d和 W c 2 ∈ R 4 d × d W_c2 \in \mathbb{R}^{4d \times d} Wc2∈R4d×d是可学习的投影矩阵, b c 1 b_c1 bc1 和 b c 2 ∈ R b_c2 \in \mathbb{R} bc2∈R 是偏置项。
CGT的主要特点:
-
多尺度特征融合:通过跨尺度注意力机制,CGT能够有效地融合来自不同尺度特征图的信息,建立更丰富的全局依赖结构。
-
计算效率提升:通过使用较小尺度的特征图作为键值对,显著降低了计算复杂度,同时保持了有效的特征交互。
-
增强的全局表示能力:跨尺度注意力不仅减少了计算量,还增强了模型对全局上下文的理解,有助于提高分割任务中的表现。
D. 边界感知局部Transformer (BLT)
BLT的设计旨在克服现有Transformer中刚性窗口划分带来的边界信息丢失问题。它包含三个主要阶段:动态边界感知窗口生成、边界附近的自注意力机制以及特征蒸馏的前馈网络(FFN)。这些设计使得BLT能够在保持计算效率的同时,更好地捕捉和处理医学图像中的边界信息。
1. 动态边界感知窗口生成
-
初始窗口集合:首先,在特征图 F 2 ∈ R 2 C × H 2 × W 2 F2 \in \mathbb{R}^{2C \times \frac{H}{2} \times \frac{W}{2}} F2∈R2C×2H×2W 上均匀密集地平铺窗口,收集所有可能的窗口位置,形成初始窗口集合 ({w})。
-
熵计算:为了确定每个窗口是否位于边界附近,首先需要计算由CGT产生的概率图 P C G T ∈ R c × H 4 × W 4 P^{CGT} \in \mathbb{R}^{c \times \frac{H}{4} \times \frac{W}{4}} PCGT∈Rc×4H×4W中每个位置 ( m , n ) (m, n) (m,n) 的熵:
C p ( m , n ) = − 1 log 2 c ∑ l = 1 c P C G T ( l , m , n ) log 2 P C G T ( l , m , n ) C_p(m, n) = -\frac{1}{\log_2 c} \sum_{l=1}^{c} P^{CGT}(l, m, n) \log_2 P^{CGT}(l, m, n) Cp(m,n)=−log2c1∑l=1cPCGT(l,m,n)log2PCGT(l,m,n)
其中 c c c 是类别数量。 -
窗口评分:然后,根据每个窗口内的熵值计算窗口得分 C w ( x , y ) C_w(x, y) Cw(x,y),其中 ( x , y ) (x, y) (x,y) 是窗口的左上角坐标, ( h , w ) (h, w) (h,w) 是窗口大小:
C w ( x , y ) = 1 ⌊ h 2 ⌋ ⌊ w 2 ⌋ ∑ m = 0 ⌊ h 2 ⌋ ∑ n = 0 ⌊ w 2 ⌋ C p ( ⌊ x 2 ⌋ + m , ⌊ y 2 ⌋ + n ) . C_w(x, y) = \frac{1}{\left\lfloor \frac{h}{2} \right\rfloor \left\lfloor \frac{w}{2} \right\rfloor} \sum_{m=0}^{\left\lfloor \frac{h}{2} \right\rfloor} \sum_{n=0}^{\left\lfloor \frac{w}{2} \right\rfloor} C_p\left(\left\lfloor \frac{x}{2} \right\rfloor + m, \left\lfloor \frac{y}{2} \right\rfloor + n\right). Cw(x,y)=⌊2h⌋⌊2w⌋1∑m=0⌊2h⌋∑n=0⌊2w⌋Cp(⌊2x⌋+m,⌊2y⌋+n). -
非极大值抑制(NMS):对初始窗口集合 {w}$ 根据上述评分进行非极大值抑制,去除冗余窗口,并通过RoIAlign与 F 2 F2 F2对齐,最终得到边界附近的窗口集合 { w ∗ } \{w*\} {w∗} 和对应的特征集合 ${f*}。
2. 边界附近的自注意力机制
-
多头自注意力:在每个边界附近的窗口内执行多头自注意力操作,以建立长距离依赖关系:
F s a ( f j ∗ ) = softmax ( f j ∗ T E q f j ∗ E k d ) f j ∗ E v , F^{sa}(f_j^*) = \text{softmax}\left(\frac{f_j^{*T} E_q f_j^* E_k}{\sqrt{d}}\right) f_j^* E_v, Fsa(fj∗)=softmax(dfj∗TEqfj∗Ek)fj∗Ev,
其中 E q ∈ R 2 C × d E_q \in \mathbb{R}^{2C \times d} Eq∈R2C×d, E k ∈ R 2 C × d E_k \in \mathbb{R}^{2C \times d} Ek∈R2C×d, E v ∈ R 2 C × d E_v \in \mathbb{R}^{2C \times d} Ev∈R2C×d 是可学习的投影矩阵, W s a ∈ R g d × d W^{sa} \in \mathbb{R}^{gd \times d} Wsa∈Rgd×d 用于组合不同头部的结果。 -
特征融合:将多头自注意力结果结合并精炼:
F s a = ( F s a , 1 ⊕ ⋯ ⊕ F s a , g ) ⋅ W s a + F 2 , F^{sa} = (F^{sa,1} \oplus \cdots \oplus F^{sa,g}) \cdot W^{sa} + F2, Fsa=(Fsa,1⊕⋯⊕Fsa,g)⋅Wsa+F2,
3. 特征蒸馏的前馈网络(FFN)
- 最终输出:经过FFN进一步处理,获得BLT的最终输出:
F B L T = ( max ( 0 , F s a W s 1 + b s 1 ) ⋅ W s 2 + b s 2 ) + F s a , F^{BLT} = (\max(0, F^{sa}W_s1 + b_s1) \cdot W_s2 + b_s2) + F^{sa}, FBLT=(max(0,FsaWs1+bs1)⋅Ws2+bs2)+Fsa,
其中 W s 1 ∈ R d × 4 d W_s1 \in \mathbb{R}^{d \times 4d} Ws1∈Rd×4d, W s 2 ∈ R 4 d × d W_s2 \in \mathbb{R}^{4d \times d} Ws2∈R4d×d 是可学习的投影矩阵, b s 1 b_s1 bs1 和 b s 2 ∈ R b_s2 \in \mathbb{R} bs2∈R 是偏置项。
计算复杂度分析
-
BLT的计算复杂度:只有边界附近的区域会通过上述边界附近的局部多头自注意力进行表示学习,因此其计算复杂度为:
Ω B L T = k ( 6 d h w C + 2 d ( h w ) 2 + d 2 h w ) , \Omega^{BLT} = k(6dhwC + 2d(hw)^2 + d^2hw), ΩBLT=k(6dhwC+2d(hw)2+d2hw),
其中 k k k 是最大窗口数。在实验中,采用 h = H 32 h = \frac{H}{32} h=32H, w = W 32 w = \frac{W}{32} w=32W,并将 k k k 设置为 α H W h w \alpha \frac{HW}{hw} αhwHW, α ∈ ( 0 , 1 ) \alpha \in (0, 1) α∈(0,1)。因此,公式8可以重写为:
Ω B L T ≈ O ( α 1 16 × 32 d ( H W ) 2 ) . \Omega^{BLT} \approx O\left(\alpha \frac{1}{16 \times 32} d(HW)^2\right). ΩBLT≈O(α16×321d(HW)2). -
与标准自注意力的比较:直接在 (F2) 上应用标准自注意力的计算复杂度为:
Ω V I T ≈ O ( 1 8 d ( H W ) 2 ) . \Omega^{VIT} \approx O\left(\frac{1}{8} d(HW)^2\right). ΩVIT≈O(81d(HW)2).
相比之下,BLT的计算复杂度最多低64倍。
BLT的主要特点
-
边界感知窗口生成:通过基于熵的评分机制和非极大值抑制,BLT能够自适应地定位边界附近的窗口,有效减少了对边界信息的破坏。
-
高效的边界识别:只对边界附近的区域应用多头自注意力,显著降低了计算复杂度,同时确保了对边界的精确建模。
-
轻量级设计:通过精心设计的计算流程和优化的参数配置,BLT在保持高效计算的同时实现了高质量的边界检测。
E. 多尺度软监督
为了应对医学图像分割中多尺度特征带来的边界不确定性问题,引入了软监督机制。这一机制旨在通过在不同尺度上提供监督信号,缓解由于下采样或上采样导致的边界信息丢失。具体来说,每个真实标签掩码 M ∈ R H × W M \in \mathbb{R}^{H \times W} M∈RH×W 会被调整到多个尺度以用于监督(例如, G B L T ∈ R c × H 2 × W 2 G^{BLT} \in \mathbb{R}^{c \times \frac{H}{2} \times \frac{W}{2}} GBLT∈Rc×2H×2W 用于训练BLT,以及 G C G T ∈ R c × H 4 × W 4 G^{CGT} \in \mathbb{R}^{c \times \frac{H}{4} \times \frac{W}{4}} GCGT∈Rc×4H×4W 用于训练CGT,同时保持原始边界的分布。
软监督的具体实现
给定类别 l l l 和下采样尺度 s s s,调整后的真值图 G l , s G_{l,s} Gl,s 的构建方法如下:
G l , s ( i , j ) = 1 2 2 ( s − 1 ) ∑ ( m , n ) ∈ O i , j ∣ M ( m , n ) = = l ∣ , G_{l,s}(i, j) = \frac{1}{2^{2(s-1)}} \sum_{(m,n) \in O_{i,j}} |M(m, n) == l|, Gl,s(i,j)=22(s−1)1∑(m,n)∈Oi,j∣M(m,n)==l∣,
其中 O i , j O_{i,j} Oi,j 表示真值图 M M M 中对应于 G l , s G_{l,s} Gl,s 中位置 ( i , j ) (i, j) (i,j) 的下采样块,定义为 M ( ( i − 1 ) 2 s − 1 + 1 : i 2 s − 1 , ( j − 1 ) 2 s − 1 + 1 : j 2 s − 1 ) M((i − 1)2^{s−1} + 1 : i2^{s−1}, (j − 1)2^{s−1} + 1 : j2^{s−1}) M((i−1)2s−1+1:i2s−1,(j−1)2s−1+1:j2s−1)。这样, G l , s G_{l,s} Gl,s 中位置 ( i , j ) (i, j) (i,j) 的概率值由下采样块内类别 l l l 出现的频率决定,并且 G l , s G_{l,s} Gl,s 中的边界分布会接近原始掩码 M M M 的边界分布。
对于给定的下采样尺度 s s s,相应地构造一组调整后的真值图 G s = { G l , s ∣ l = 1 , 2 , . . . , c } G_s = \{G_{l,s}|l = 1, 2, ..., c\} Gs={Gl,s∣l=1,2,...,c}。然后,定义相同尺度下的预测图 P s = { P l , s ∣ l = 1 , 2 , . . . , c } P_s = \{P_{l,s}|l = 1, 2, ..., c\} Ps={Pl,s∣l=1,2,...,c} 与真值图 G s G_s Gs 之间的损失函数为:
L s = ∑ l = 1 c ∑ i = 1 H 2 s − 1 ∑ j = 1 W 2 s − 1 ∣ G l , s ( i , j ) − P l , s ( i , j ) ∣ . L_s = \sum_{l=1}^{c} \sum_{i=1}^{\frac{H}{2^{s-1}}} \sum_{j=1}^{\frac{W}{2^{s-1}}} |G_{l,s}(i, j) - P_{l,s}(i, j)|. Ls=∑l=1c∑i=12s−1H∑j=12s−1W∣Gl,s(i,j)−Pl,s(i,j)∣.
设置 s = 1 , 2 , 3 s = 1, 2, 3 s=1,2,3 分别定义用于CNN、BLT和CGT的损失函数 KaTeX parse error: Can't use function '\)' in math mode at position 8: L^{CNN}\̲)̲,\(L^{BLT},和 L C G T L^{CGT} LCGT。这些特征进一步融合以生成最终预测,并通过交叉熵损失和Dice损失进行惩罚,记作 L C L^{C} LC。
总体损失函数
多尺度软监督的整体损失函数 L L L 定义为:
L = β 1 L C N N + β 2 L B L T + β 3 L C G T + β 4 L C , L = \beta_1 L^{CNN} + \beta_2 L^{BLT} + \beta_3 L^{CGT} + \beta_4 L^{C}, L=β1LCNN+β2LBLT+β3LCGT+β4LC,
其中 β 1 , β 2 , β 3 , \beta_1, \beta_2, \beta_3, β1,β2,β3, 和 β 4 \beta_4 β4 是平衡超参数,在实验中分别设置为 0.2, 0.1, 0.1, 和 0.6。
多尺度软监督的主要特点
-
缓解边界不确定性:通过在不同尺度上提供监督信号,确保模型在各个尺度上都能学习到准确的边界信息,减少了因下采样或上采样带来的边界信息丢失。
-
保持原始边界分布:调整后的真值图保留了原始掩码中的边界分布,使得模型能够更好地学习边界特征。
-
综合多种损失:结合了CNN、BLT、CGT模块各自的损失以及最终预测的交叉熵和Dice损失,使得模型能够在多个层次上进行优化,提高了整体分割性能。
IV. 评估
本节对BATFormer进行了广泛的对比实验,涵盖了最先进的基于CNN和基于Transformer(或混合)的方法,并在公开可用的数据集上进行了评估。
A. 数据集
评估使用了两个涵盖2D和3D医学图像数据的数据集:
-
ACDC:由150名患者的心脏MRI组成
-
ISIC 2018:包含2596张带有良好标注的皮肤病变图像