1 要点
代码:
研究目的:
WSI分辨率巨大且缺乏细粒度注释,使其分类和分析任务面临挑战。此外,现有基于注意力机制的方法在处理WSI时存在计算复杂度高和无法捕捉实例间信息的问题。因此,本文所提出的AMD-MIL方法旨在通过动态调整特征表示,并引入掩码去噪机制来改进注意力分配,从而提高对WSI的检测能力和模型的可解释性。
关键技术:
- 代理聚合器:
引入代理Token,作为查询和键之间的中间变量,用于计算实例的重要性; - 掩码去噪机制:
利用从代理聚合值映射的掩码和去噪矩阵,动态地掩盖低贡献表示并消除噪声; - 可训练代理Token:
将代理令牌从通过均值池化得到的非训练参数转换为可训练的矩阵,以更有效地映射特征。
数据集:
- Camelyon16
- Camelyon16
- TCGA-KIDNEY
- TCGA-LUNG
引用:
@inproceedings{Ling:2024:agent:19,
author = {Ling, Xitong and Ouyang, Minxi and Wang, Yizhi and Chen, Xinrui and Yan, Renao and Cheng, Junru and Guan, Tian and Liu, Xiaoping and Tian, Sufang and He, Yonghong},
title = {Agent aggregator with mask denoise mechanism for histopathology whole slide image analysis},
booktitle = {{ACM MM}},
years = {2024},
pages = {1--9},
url = {https://openreview.net/pdf?id=qQTr32f832}
}
2 方法
2.1 MIL和特征提取
在多示例学习 (MIL)中,每张WSI被概念化为一个有标记的包,其中包含的各个切片被视为具有不确定标签的实例。以二分类为例,每个WSI
X
X
X被划分为多个区块
{
(
x
1
,
y
1
)
,
…
,
(
x
N
,
y
N
)
}
\{(x_1,y_1), \ldots , (x_N,y_N)\}
{(x1,y1),…,(xN,yN)},其中
N
N
N是实例数量。在MIL框架下,包标签
Y
Y
Y与实例标签
y
i
y_i
yi之间的相关性如下:
Y
=
{
1
,
iff
∑
i
=
1
N
y
i
>
0
;
0
,
else
,
(1)
\tag{1} Y = \begin{cases} 1, & \text{iff } \sum_{i=1}^{N} y_i > 0;\\ 0, & \text{else}, \end{cases}
Y={1,0,iff ∑i=1Nyi>0;else,(1)在训练期间,实例
y
i
y_i
yi的标签不可用。此外,WSI分类的目标是开发一个分类器
M
(
X
)
\mathcal{M}(X)
M(X),用于预测包标签
Y
^
\hat{Y}
Y^:
Y
^
←
M
(
X
)
:
=
h
(
g
(
f
(
X
)
)
)
)
,
(2)
\tag{2} \hat{Y} \leftarrow \mathcal{M}(X) := h(g(f(X)))),
Y^←M(X):=h(g(f(X)))),(2)其中
f
f
f、
g
g
g,以及
h
h
h分别代表特征提取器、特征聚合器,以及MIL分类器。
特征聚合器被认为是总结特征的最重要部分,它可以聚合不同区块的特征。注意力机制可以识别WSI中区块的重要性,并且在特征聚合器中广泛使用:
- 基于注意力机制的特征聚合器:
G = ∑ i = 1 N a i h i = ∑ i = 1 N a i f ( x i ) ∈ R D , (3) \tag{3} G = \sum_{i=1}^{N} a_i h_i=\sum_{i=1}^Na_if(x_i)\in \mathbb{R}^D, G=i=1∑Naihi=i=1∑Naif(xi)∈RD,(3)其中 G G G是包表示、 h i ∈ R D h_i \in \mathbb{R}^D hi∈RD是通过特征提取器 f f f提取的区块 x i x_i xi的特征、 a i a_i ai可训练注意力系数,以及 D D D是向量 G G G和 h i h_i hi的维度; - 基于自注意力的特征聚合器:
Q = H W Q , K = H W K , V = H W v , (4) \tag{4} Q=HW_Q,\ K=HW_K,\ V=HW_v, Q=HWQ, K=HWK, V=HWv,(4) O = softmax ( Q K T d q ) V = S V (5) \tag{5} O=\text{softmax}\left( \frac{QK^T}{\sqrt{d_q}} \right) V=SV O=softmax(dqQKT)V=SV(5)其中 W q W_q Wq、 W K W_K WK,以及 W v W_v Wv分别代表的可训练参数矩阵、 H H H表示区块特征的集合、 O O O是所有特征的整合,以及 d q d_q dq是查询向量的维度。
2.2 注意力聚合器
在公式5,
Sim
(
Q
,
K
)
\text{Sim}(Q, K)
Sim(Q,K)的计算复杂度为
O
(
N
2
)
O(N^2)
O(N2)。由于
N
N
N通常包含数千个元素,这显著延长了预期的计算时间。线性注意力可以减少计算时间,但以牺牲信息为代价。为了缓解这一问题,TransMIL采用了Nyström近似来处理公式5中的自注意力,即构建矩阵
Q
~
\tilde{Q}
Q~和
K
~
\tilde{K}
K~,并计算每个片段 (segment) 的均值如下:
Q
~
=
[
q
~
1
;
…
;
q
~
m
]
,
q
~
j
=
1
m
∑
i
=
(
j
−
1
)
×
l
+
1
(
j
−
1
)
×
l
+
m
q
i
,
∀
j
=
1
,
…
,
m
(6)
\tag{6} \tilde{Q} = [ \tilde{q}_1; \ldots; \tilde{q}_m], \ \tilde{q}_j = \frac{1}{m} \sum_{i=(j-1)\times l+1}^{(j-1)\times l+m} q_i,\ \forall j = 1, \ldots, m
Q~=[q~1;…;q~m], q~j=m1i=(j−1)×l+1∑(j−1)×l+mqi, ∀j=1,…,m(6)
K
~
=
[
k
~
1
;
…
;
k
~
m
]
,
k
~
j
=
1
m
∑
i
=
(
j
−
1
)
×
l
+
1
(
j
−
1
)
×
l
+
m
k
i
,
∀
j
=
1
,
…
,
m
(7)
\tag{7} \tilde{K} = [ \tilde{k}_1; \ldots; \tilde{k}_m], \ \tilde{k}_j = \frac{1}{m} \sum_{i=(j-1)\times l+1}^{(j-1)\times l+m} k_i,\ \forall j = 1, \ldots, m
K~=[k~1;…;k~m], k~j=m1i=(j−1)×l+1∑(j−1)×l+mki, ∀j=1,…,m(7)其中,
Q
~
,
K
~
∈
R
m
×
D
\tilde{Q},\tilde{K} ∈ \mathbb{R}^{m\times D}
Q~,K~∈Rm×D。然后,公式5中的
S
S
S可以被近似为:
S
^
=
softmax
(
Q
K
~
T
d
q
)
Z
∗
softmax
(
Q
~
K
T
d
q
)
,
(8)
\tag{8} \hat{S} = \text{softmax}\left( \frac{Q \tilde{K}^T}{\sqrt{d_q}} \right)Z^*\text{softmax}\left( \frac{\tilde{Q} K^T}{\sqrt{d_q}} \right),
S^=softmax(dqQK~T)Z∗softmax(dqQ~KT),(8)其中,
Z
∗
Z^*
Z∗表示
z
(
Q
~
,
K
~
,
Z
)
=
0
z( \tilde{Q}, \tilde{K}, Z) = 0
z(Q~,K~,Z)=0的近似解,其仅需线性数量的迭代来收敛。
难点:
- 主要在于理解片段,因为公式6–7中的 m m m和 l l l未解释清楚,这需要提前了解transmil才能了解;
- 简单来说,给定的WSI,其有 N N N个实例,所有的实例可以按顺序被分为 L L L个片段,这里的 L L L是为了和本文的 l l l对应,原文并没有这个符号;
- 每个片段的大小为 m m m,因此公式6的含义即为:在给定的第 l l l个片段内,如何通过索引来计算新的 q ~ j \tilde{q}_j q~j,公式7同理;
在MIL任务中,Nyström注意力会因为采样机制而过滤掉具有重要特征的切片。此外,
N
N
N的差异会导致局部下采样时的整体不平衡。因此,本文考虑具有线性时间复杂度的代理注意力方法:
O
=
σ
(
Q
A
T
)
σ
(
A
K
T
)
V
(9)
\tag{9} O = \sigma(Q A^T)\sigma(AK^T)V
O=σ(QAT)σ(AKT)V(9)其中
σ
(
⋅
)
\sigma(\cdot)
σ(⋅)是Softmax函数、
A
∈
R
n
×
D
A ∈ \mathbb{R}^{n\times D}
A∈Rn×D是池化
Q
Q
Q得到的代理矩阵 (这里的使用的平均池化),以及
n
n
n是代理的维度,其是一个超参数。由于代理是非训练的,并且注意力分数的分布可能不是最优的,因此有必要建立一个能够动态调整注意力分数分布以增强模型性能和灵活性的自适应代理。
2.3 代理掩码去噪机制
公式5和公式9构成了本文的基底,如图1。图2则展示了所提出方法的总体框架。在预处理特征输入到模型之前,一个类别Token与这些特征直接拼接,这一点和transmil类型,得到特征矩阵 D ∈ R D × ( N + 1 ) D \in \mathbb{R}^{D \times (N+1)} D∈RD×(N+1),其中 D D D是特征的维度,以及 ( N + 1 ) (N+1) (N+1)表示包括嵌入的类别Toekn在内的区块数量。
3.3.1 可训练代理
公式9中的矩阵 A A A最初是在 Q Q Q的基础上,使用均值池化获得,即 A = p o o l i n g ( Q ) ∈ R n × D A = pooling(Q) \in \mathbb{R}^{n \times D} A=pooling(Q)∈Rn×D,这限制了 Q Q Q中信息的完整性。对此,本文将 A A A定义为一个可训练的矩阵:
- 基于矩阵 A ∈ R n × D A \in \mathbb{R}^{n \times D} A∈Rn×D,可以计算中间矩阵 Q A = Q A T ∈ R ( N + 1 ) × n Q_A = QA^T \in \mathbb{R}^{(N+1) \times n} QA=QAT∈R(N+1)×n和$K_A = AK^T \in \mathbb{R}^{n \times (N+1)} );
- 使用通用注意力策略,中间变量
V
A
V_A
VA计算为:
V A = σ ( K A ) V = σ ( A K T ) V ∈ R n × D . (10) \tag{10} V_A = \sigma(K_A)V=\sigma(AK^T)V \in \mathbb{R}^{n \times D}. VA=σ(KA)V=σ(AKT)V∈Rn×D.(10)
难点:
- n n n:可训练矩阵 A A A的节点参数,具体可以参见算法1,其实就是在网络初始化时,生成一个 B × n × D B\times n \times D B×n×D的矩阵。
3.3.2 掩码代理
在MIL任务中,WSI中的大多数区块对预测的贡献不大,因此使用可训练的阈值生成可学习的掩码:
τ
=
σ
(
p
(
W
τ
V
A
T
)
)
,
(11)
\tag{11} \tau = \sigma(p (W_\tau V_A^T)),
τ=σ(p(WτVAT)),(11)其中
W
τ
∈
R
1
×
D
W_\tau \in \mathbb{R}^{1 \times D}
Wτ∈R1×D、
p
p
p是一个类似于均值池化的可调节聚合函数,以及
τ
\tau
τ阈值。进一步,每个特征的重要性将被计算,以优化隐藏空间中的重要特征。然而,特征选择会有丢失信息的风险。为了平衡重要信息选择和聚合原始特征之间的特性,提出了一个新的模块:
V
M
D
i
j
=
V
A
i
j
I
(
M
i
j
>
τ
)
+
D
N
i
j
,
(12)
\tag{12} V_{MD_{ij}} = V_{A_{ij}}\mathbb{I}(M_{ij}>\tau)+DN_{ij},
VMDij=VAijI(Mij>τ)+DNij,(12)其中
M
=
W
M
V
A
M=W_MV_A
M=WMVA是阈值矩阵,用于获得每个特征的重要性、
D
N
=
W
D
N
V
A
DN=W_{DN}V_A
DN=WDNVA是用于聚合信息的去噪矩阵,以及
W
M
W_M
WM和
W
D
N
W_{DN}
WDN是可学习的参数。
3.3.3 代理可视化
代理注意力无法像通常注意力机制那样为每个实例生成一个注意力分数,这使得实验中的区块预测可视化变得困难。对此,提出了一种注意力分数的可视化策略:
A
t
t
i
=
∑
j
=
1
n
Q
A
0
,
j
K
A
j
,
i
+
1
(13)
\tag{13} Att_i = \sum_{j=1}^{n} Q_{A_{0,j}}K_{A_{j,i+1}}
Atti=j=1∑nQA0,jKAj,i+1(13)其中
A
t
t
i
Att_i
Atti就可以作为特征
h
i
h_i
hi的注意力分数。
3.3.4 AMD
引入了一个名为掩码去噪机制的新框架,如图2所示,其包括基于学习的代理注意力机制、表示细化。以及特征聚合。算法过程如算法1,且该模块可以表示为:
O
=
σ
(
Q
A
T
)
V
M
D
,
(14)
\tag{14} O = \sigma(QA^T)V_{MD},
O=σ(QAT)VMD,(14)其中
m
d
md
md表示掩码去噪机制,以及
V
M
D
V_{MD}
VMD表示基于公式12计算出的矩阵。
由于阈值选择方法的差异,另外两种特征阈值选择策略如下:
- Mean-AMD:均值选择,选择所有特征的平均值作为阈值;
- CNN-AMD。CNN选择,通过组卷积方法减少不同组的特征,并使用组间平均值作为阈值。