@inproceedings{lin2023interventional,
title={Interventional bag multi-instance learning on whole-slide pathological images},
author={Lin, Tiancheng and Yu, Zhimiao and Hu, Hongyu and Xu, Yi and Chen, Chang-Wen},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={19830--19839},
year={2023}
}
1、摘要
多示例学习是解决千兆像素分辨力和幻灯片级别标签的全幻灯片病理图像(whole-slide pathological images , WSIs)的有效工具。以往的主流MIL方法主要关注改进特征提取器以及聚合器。然而这样的方法存在这样的缺陷:包的上下文先验(contextual prior)可能会影响模型捕捉包与标签之间的虚假相关性(spurious correlations)。 该缺陷是限制现有MIL方法性能的混杂因素。
本文提出了新的方案:介入式的包多示例学习(Interventional Bag Multi-Instance
Learning , IBMIL),以实现卷积包级别的预测。
与传统的基于可能性的策略不同,该方案基于后门调节(backdoor adjustment)实现了介入学习,因此能够抑制包的上下文先验对相关性的影响。
IBMIL的机制与现有MIL包算法完全不同。因此,IBMIL能够为现有方案带来一致的性能提升,实现最先进的新性能。
2、引言
全幻灯片病理图像的分析对诊断与研究都至关重要(很多论文都是聚焦于WSIs图像)。
与自然图像不同,WSI拥有千兆像素的分辨率,同时为WSI图像进行注释与标签需要极大的成本。WSI特殊的数据形式与多示例学习问题密切联系:将每个WSI图像视为一个有标签的包,每一个patch视为无标签的实例。
WSI分类的主流方案:包级MIL如图所示:
首先,每个WSI中的patch代表一个实例,每个实例都由第一阶段中的特征提取器嵌入为向量。其次,对于每个包,它们的实例特征都在第二阶段聚合为包级特征用于分类。
越来越多的新方法大多都聚焦于改进该方案的后面两个阶段。研究人员坚信学习更好的实例特征和建模更准确的实例关系可以带来更好的MIL性能。
然而,现有的问题依然没有解决包的上下文先验问题(contextual prior):同类包共享的信息与标签无关,这可能会影响最终的预测。 例如下图所示,由于数据集的偏差,正包中的大多数实例染色为粉色,但是在负包中大多数实例染色为紫色。
这种特定颜色模式和标签的同时出现可能会误导模型,使其基于颜色而并非关键实例对包进行分类。
下图说明了另一个例子:即使预测是正确的,潜在的视觉注意力也是不合理的,因为高注意力得分被放在袋子中蓝色曲线之外的与疾病无关的实例上。
从因果角度来看,包的上下文先验是一个混杂因素,为袋子和标签打开了后门,导致它们之间的虚假相关性。
为了抑制这种偏差,我们需要一种更有效的机制来处理包与标签之间的实际因果关系,即包的预测是基于包的内容(如关键实例)。靠上述现有的新框架难以实现。事实上,实现无偏差的包预测是一种挑战,因为在数据生成过程(如组织制备、染色方案以及数字扫描)会发生许多的偏差。
在本文中,我们提出了一种新的MIL方案,即:介入式包多示例学习(Interventional Bag Multi-Instance Learning, IBMIL)。该方案的特别之处在于本文提出了一个结构因果模型(structure causal mdel, SCM)来分析上下文先验、包和标签之间的因果关系。同时,IBMIL与其他算法的关键差别在于它包含了另一个介入训练阶段。
给定在第二阶段的使用的聚合器,我们将其用于混杂因素的近似,而不是直接使用它来通过似然函数
P
(
Y
∣
X
)
P(Y|X)
P(Y∣X)进行推断(*******)。通过观察到的混杂因素,我们通过后门调整公式(其他文章中的,Causal diagrams for empirical research)来消除混杂因素的影响。其中,直观的理解是:如果WSI模型可以分别从“紫色”和“粉色”的正/负包中学习,那么袋子的颜色上下文将不再混淆识别。由于本文使用因果介入
P
(
Y
∣
d
o
(
X
)
)
P(Y|do(X))
P(Y∣do(X))进行包预测,使得IBMIL与现有方法有根本不同。
3、相关工作
分别介绍了基于包和基于实例的方法来处理WSIs数据。实例级别MIL方法通过分数表示每个实例,并将实例分数聚合到包分数中。但实例级MIL方法的性能通常不如包级方法。在包级方法中,实例表示为嵌入向量,并按包与包的距离/相似性或包分类器进行分类。
然而,在WSIs数据上进行包级分类是较为困难的,因为所有补丁的中间结果仍然需要存储在存储器中以进行反向传播。因此,最近提出的一些框架将实例级特征提取器和聚合网络的训练分开,从而产生了两阶段建模方法。贡献差异主要体现在两个阶段之中。
许多的现有工作聚焦设计全新的聚合网络,从非参数池(最大/平均池)到可学习的(图卷积网络、注意力机制)。本文的工作就在这方面,但旨在增强这些现有方法的能力。
因果推理是一个通用框架,已被引入各种计算机视觉任务。与其他引入因果推理的MIL算法不同,IBMIL基于后门调整公式,并作为一个通用框架,为WSI分类任务授权现有的包级MIL。
4、方法
4.1、预设
MIL问题:将每个WSI都被视为一个带标签的包,相应的补丁(patch)作为没有标签的实例。(后面就是标准的MIL假设)
一般的三阶段方法如下:
1)实例转换:为实例特征b训练特征提取器
f
(
⋅
)
f(·)
f(⋅);
2)实例组合:针对包特征B进行池化操作
σ
(
⋅
)
σ(·)
σ(⋅);
3)包转换:使用下游分类器g(·)进行预测,
以上过程可以公式化为:
b
i
=
f
(
x
i
)
,
B
=
σ
(
b
1
,
.
.
.
,
b
n
)
,
Y
^
=
g
(
B
)
(1)
b_i=f(x_i),B=\sigma(b_1,...,b_n),\hat{Y}=g(B)\tag{1}
bi=f(xi),B=σ(b1,...,bn),Y^=g(B)(1)
其中,池化函数
σ
(
⋅
)
\sigma(·)
σ(⋅)应为MIL方法的置换不变( permutation-invariant)函数。一些算法进一步将分类器
g
(
⋅
)
g(·)
g(⋅)加入到池化操作
σ
(
⋅
)
\sigma(·)
σ(⋅)中,称为聚合器/聚合网络。
当将MIL方法应用于WSI时,应该注意的是:
1)WSI分析的诊断可以基于具有多个概念的不同组织区域——集体MIL假设;
2)WSI的包长度n可能非常大,例如,平均约8000。
因此,用于WSI的包MIL方法在两阶段过程中进行训练:逐阶段同时训练特征提取器和聚合器。目前的工作主要遵循这一准则,并从特征提取器和聚合器两个方面改进了框架,而我们提出的方法旨在从因果角度(causal perspective)改进现有工作。
通过因果推断分析MIL:如图所示,本文将MIL框架转化为因果图,包含三个节点:
X
X
X:整张幻灯片的病例图形,即包;
Y
Y
Y:包标签;
C
C
C:包上下文信息。
X
→
Y
X→Y
X→Y:该路径表明MIL模型可以通过包的内容(如关键实例)来预测包标签;
C
→
X
C→X
C→X:该路径表明了WSI的生成。由于组织制备、染色方案和数字扫描仪的差异,WSI的外观可能会受到显著影响,从而可能引入偏差。
C
→
Y
C→Y
C→Y:该路径表明预测包标签收到数据集的上下文先验信息的影响。
以一个例子来说明这种影响:如下图的左侧所示,上面的图像为正,下面的为负。但因为都染成了粉色,MIL模型受到了先验信息的影响,于是将粉色的图像都预测为正,而与真实标签相关的内容信息无关。
理想的MIL方法应该捕捉
X
X
X和
Y
Y
Y之间的真实因果关系,但使用
P
(
Y
∣
X
)
P(Y|X)
P(Y∣X)的传统相关性无法做到这一点。
因此,本文使用因果介入
P
(
Y
∣
d
o
(
X
)
)
P(Y|do(X))
P(Y∣do(X)),其中
d
o
(
⋅
)
do(·)
do(⋅)代表着强制为
X
X
X赋值。在因果图的右侧可以看到:直接切断
C
C
C对
X
X
X的影响,从而减轻混杂因素造成的偏差。
4.2、Interventional Bag Multi-Instance Learning
本文提出了后门调整公式来实现为包级别预测的因果介入
P
(
Y
∣
d
o
(
X
)
)
P(Y|do(X))
P(Y∣do(X)):
P
(
Y
∣
d
o
(
X
)
)
=
∑
i
P
(
Y
∣
X
,
h
(
X
,
c
i
)
)
P
(
c
i
)
(2)
P(Y|do(X))=\sum_iP(Y|X,h(X,c_i))P(c_i)\tag{2}
P(Y∣do(X))=i∑P(Y∣X,h(X,ci))P(ci)(2)
其中,
c
i
c_i
ci
因为因果介入迫使
X
X
X公平的合并每个
c
I
c_I
cI,
c
i
c_i
ci不再受
X
X
X的影响而是受
P
(
c
i
)
P(c_i)
P(ci)的影响。
下图为IBMIL的总体框架图,主要分为三个阶段:
阶段1:训练特征提取器
在WSIs图像上学习特征提取器 f ( ⋅ ) f(·) f(⋅),旨在将每个实例编码为判别特征向量。
阶段2:训练聚合器
给定阶段1提取好的实例特征
{
b
1
,
.
.
.
,
b
n
}
\left \{ b_1,...,b_n \right \}
{b1,...,bn},利用使用了MIL池化
σ
(
⋅
)
\sigma(·)
σ(⋅)的聚合器将这些实例特征组合为包特征
B
B
B,并使用分类器
g
(
⋅
)
g(·)
g(⋅)分类。训练聚合器的损失函数定义为:
L
=
−
1
N
∑
i
=
1
N
Y
i
log
Y
^
i
+
(
1
−
Y
i
)
l
o
g
(
1
−
Y
^
i
)
(3)
L=-\frac{1}{N}\sum_{i=1}^{N}Y_i\log\hat{Y}_i+(1-Y_i)log(1-\hat{Y}_i)\tag{3}
L=−N1i=1∑NYilogY^i+(1−Yi)log(1−Y^i)(3)
其中,
N
N
N表示训练集中的包数量。该损失函数就是来度量预测标签跟真实标签的差距,并且求了均值。损失函数使得预测标签更加接近真实标签。
IBMIL并不局限于特定的聚合器或特征提取器。
阶段3:通过后门调整进行因果干预
传统的方法在阶段2时就结束了。而本文引入了一个从头开始的干预训练阶段,这需要公式(2)的具体实施。后门调整假设我们可以观察包上下文的混杂因素。
由于深度学习MIL模型的强大,上下文信息自然地被编码在更高层。为了构成混杂因素集合。因为收集所有混杂因素是不可能的,所以本文使用混杂因素字典 C = [ c 1 , . . . , c K ] C=[c_1,...,c_K] C=[c1,...,cK]进行近似。
给定训练好的特征提取器和聚合器,我们对训练集中的所有包特征
B
B
B使用k-means聚类,将包分成几个簇。对每个聚类簇中的包特征进行平均,以表示混杂层
c
i
c_i
ci,从而形成
d
×
K
d×K
d×K的混杂字典,其中
d
d
d为包特征维度,
K
K
K为聚类簇数量。
因为这些在全局进行聚类容易受到视觉偏差的影响,而视角偏差正是混杂因素,因此可以通过聚类来近似混杂因素。
定义:
h
(
X
,
c
i
)
=
α
i
c
i
(4)
h(X,c_i)=\alpha_ic_i\tag{4}
h(X,ci)=αici(4)
[
α
1
,
.
.
.
,
α
K
]
=
s
o
f
t
m
a
x
(
(
W
1
B
)
T
(
W
2
C
)
l
)
\left [ \alpha _1,..., \alpha _K\right ] =softmax(\frac{\left ( W_1B \right )^T\left ( W_2C \right ) }{\sqrt{l} })
[α1,...,αK]=softmax(l(W1B)T(W2C))
其中,
B
=
σ
(
f
(
X
)
)
B=\sigma(f(X))
B=σ(f(X))是包特征,
W
1
,
W
2
W_1,W_2
W1,W2是将包特征
B
B
B和混杂因素
C
C
C投影到联合空间中的两个可学习投影矩阵,
l
\sqrt{l}
l用于特征归一化。
由于预测由包
X
X
X和混杂因素
C
C
C共同决定,因此进一步定义:
P
(
Y
∣
X
,
h
(
X
,
c
i
)
)
=
P
(
Y
∣
B
⊕
h
(
X
,
c
i
)
)
(5)
P(Y|X,h(X,c_i))=P(Y|B\oplus h(X,c_i))\tag{5}
P(Y∣X,h(X,ci))=P(Y∣B⊕h(X,ci))(5)
其中,
⊕
\oplus
⊕表示向量拼接。
将公式(4)与公式(5)的内容插入到公式(2)中,并通过多次传递网络来计算
P
(
Y
∣
d
o
(
X
)
)
P(Y|do(X))
P(Y∣do(X)),结合归一化加权几何平均:
P
(
Y
∣
d
o
(
X
)
)
≈
P
(
Y
∣
B
⊕
∑
i
=
1
K
α
i
c
i
P
(
c
i
)
)
(6)
P(Y|do(X))\approx P(Y|B\oplus \sum_{i=1}^K\alpha _ic_iP(c_i))\tag{6}
P(Y∣do(X))≈P(Y∣B⊕i=1∑KαiciP(ci))(6)
5、Justification
应用在大规模无标签数据集:
本文以无监督的方式(没有标签)构成混杂因素集合。但同时,可以选用包标签对聚类过程进行引导,从而保留类内差异并捕捉类相关特征。
更简便的方案:
由于我们需要经过训练的聚合器来生成袋子特征(阶段2),因此还需要一个阶段来重新训练聚合器(阶段3)。因此,我们能够简化这一方案。
具体来说,我们可以通过将传统的非参数聚合器(例如,最大/平均池)应用于包中的实例来实现包的特征。
5、实验
数据集与评测机制:
我们在两个公共的WSI数据集上进行了实验,即Camelyon16和TCGA-NSCLC。
我们为Camelyon16使用了270张训练图像和129张测试图像,为TCGA-NSCLC使用了836张训练图像,并使用了210张测试图像(丢弃了一些损坏的幻灯片)。我们报告了分类精度、召回率、准确性和曲线下面积(AUC)得分。
特征提取器:
我们采用不同的网络架构和不同的训练模式来全面评估我们的IBMIL。诸如:ResNet-18 、ViT-small和CTransPath。
MIL模型聚合器:
我们在4种方法的基础上构建了本文的方法。诸如:ABMIL、DSMIL、Trans-MIL和DFTD-MIL。为了与DSMIL保持一致,我们使用最大注意力分数选择(MaxS)作为特征提取策略。
由于IBMIL的特征提取器是经过预训练的,因此能够直接将实例转换为特征向量(第一阶段)。对于第二跟第三阶段,所有的MIL模型都经过50次迭代,学习率为0.0001。将聚类簇数设置为 K = 8 K=8 K=8,投影维度 l = 128 l=128 l=128。