文章目录
1 概述
1.1 题目
1.2 摘要
多示例学习 (MIL) 广泛应用于自动全幻灯片图像 (Whole slide image, WSI) 分析,其处理策略可以分为:
- 实例特征提取;
- 特征聚合。
然而,由于幻灯片级别标签的弱监督性,MIL模型的训练过程通常会呈现严重的过拟合。在这种情况下,从有限的幻灯片级别标注的数据中发掘更多的信息是至关重要的。
与已有的方法不同,本文着重于探索不同实例 (区块) 之间的潜在关系,而非提升实例特征的提取,以提高模型的泛化能力。具体地,MuRCL从以下几个角度处理问题:
- 训练自监督管理器,然后基于WSI幻灯片级别标签微调。这个过程被称为对比学习 (Contrastive learning, CL),其基于WSI中相同的区块级特征包构建了正/负判别特征集;
- 为了加速CL训练,设计了一个基于强化学习的代理,根据幻灯片级特征聚合的在线奖励 (Online reward) 以逐步更新辨别特征集的选择。然后使用标记的WSI数据来更新模型和习得特征,并获取最终的WSI分类。
实验在三个公开的WSI分类数据集上进行,包括Camelyon16、TCGA-Lung,以及TCGA-Kidney。实验结果验证了MuRCL的性能,其在TCGA-Lung数据集的效果尤为明显。
图1展示了MuRCL与一般的MIL的区别。
1.3 代码
Torch:https://github.com/wwu98934/MuRCL
1.4 引用
@article{Zhu:2022:113,
author = {Zhong Hang Zhu and Le Quan Yu and Wei Wu and Rong Shan Yu and De Fu Zhang and Lian Sheng Wang},
title = {{MuRCL}: {M}ulti-instance reinforcement contrastive learning for whole slide image classification},
journal = {{IEEE} Transactions on Medical Imaging},
pages = {1--13}
year = {2022},
doi = {10.1109/TMI.2022.3227066}
}
2 方法
图2:MuRCL的自监督学习过程:(a) MuRCL:给定一个WSI的输入特征包 (WSI-Fbag)
x
x
x,两个RL-MIL分支的输出是用于对比损失的正对;(b) RL-MIL:一个奖励导向的代理
R
\boldsymbol{R}
R被用于从
x
x
x选择辨别特征集 (WSI-Fset)
x
~
\tilde{x}
x~。
x
~
\tilde{x}
x~首先被随机初始化,然后通过
R
\boldsymbol{R}
R更新。
x
~
\tilde{x}
x~传递给MIL聚合器,以生成特征嵌入
v
v
v,并通过映射头
f
(
⋅
)
f(\cdot)
f(⋅)输出
p
p
p,其将用于计算对比损失;(
c
\text{c}
c) RL选择:
s
t
s_t
st用于从由
s
t
k
s_t^k
stk构成的关联图的每个簇中选出索引特征
图2(a)展示了MuRCL的总体框架,其通过代理构建用于对比学习的正/负对,并从输入特征包中选择两个独立的辨别集。然后,MIL汇聚器 M ( ⋅ ) M(\cdot) M(⋅)和投影头 f ( ⋅ ) f(\cdot) f(⋅)使用对比损失来最大化正辨别集之间的一致性。MuRCL的每个分支是一个RL-MIL。
图2(b)展示了RL-MIL的序列决策过程。给定输入包,RL-MIL迭代生成一个辨别集的序列,并使用MIL聚合器和投影头输出特征向量序列。特别地,在每一步,代理 (辨别集网络) 确定一组特征索引。然后,通过组合每个包簇的索引特征来构建下一步的判别集,如图2( c \text{c} c)。最后利用WSI标签微调自监督预训练模型,以获取用于预测的表示。
2.1 多示例对比学习
多示例对比学习使用WSI-Fbag作为输入,其中每个包包含了在ImageNet上训练的ResNet18处理后的区块级嵌入。CL的一个关键步骤是构建用于训练的逻辑正/负对 (即语义相似/不相似实例)。与已有的基于图像增强的策略不同,我们从每个WSI-Fbag中采样不同的WSI辨别集合 (更小的WSI-Fset),并构建用于CL训练的基于集合的正/负对。
每个WSI-Fset是从WSI-Fbag的多个特征簇得到的子集的组合。特别地,给定WSI0Fbag x x x,首先使用聚类算法,例如kMeans将其划分为 K K K个簇 C k ( k ∈ [ 1 , 2 , … , K ] ) C_k(k\in[1,2,\dots,K]) Ck(k∈[1,2,…,K])。通过从 C k C_k Ck采样到一个子集,WSI-Fset x ~ \tilde{x} x~是所有这些子集的拼接。每一个簇的采样率保持一致,因此 x ~ \tilde{x} x~有常量个实例嵌入。
随后, x ~ \tilde{x} x~被传递给MIL聚合器 M ( ⋅ ) M(\cdot) M(⋅)以及投影头 f ( ⋅ ) f(\cdot) f(⋅),以获取WSI级别特征嵌入 p p p。值得注意的是这里有多种采样策略,因此能够获得不同的 x ~ \tilde{x} x~以及相应的嵌入。不同的 x ~ \tilde{x} x~可以看作是相同WSI的不同视角,可被用于构建CL中的正对。
令
{
x
~
n
}
n
=
1
N
\{\tilde{x}_n\}_{n=1}^N
{x~n}n=1N表示
N
N
N个WSI-Fset的组,其中
x
~
i
\tilde{x}_i
x~i和
x
~
j
\tilde{x}_j
x~j采样自同一个
x
x
x,其他的采样自不同的WSI-Fbag。然后,CL损失计算为:
L
i
,
j
=
−
log
exp
(
s
i
m
(
p
i
,
p
j
)
/
τ
)
∑
n
=
1
N
1
(
n
≠
i
)
exp
(
s
i
m
(
p
i
,
p
n
)
/
τ
)
,
(1)
\tag{1} L_{i,j}=-\log\frac{\exp(sim(p_i,p_j)/\tau)}{\sum_{n=1}^N1(n\neq i)\exp(sim(p_i,p_n)/\tau)},
Li,j=−log∑n=1N1(n=i)exp(sim(pi,pn)/τ)exp(sim(pi,pj)/τ),(1)其中
τ
\tau
τ是温度参数、
s
i
m
(
⋅
,
⋅
)
sim(\cdot,\cdot)
sim(⋅,⋅)表示两个向量之间的余弦相似度、
1
(
n
≠
i
)
∈
{
0
,
1
}
1(n\neq i)\in\{0,1\}
1(n=i)∈{0,1}是一个指示函数,即如果条件满足则为
1
1
1,反之为
0
0
0。本文使用NT-Xent作为目标函数,以最大化正对之间的相似度和最小化负对之间的相似度。在这种情况下,MIL聚合器将能够习得用于准确分类的聚合知识。
2.2 基于RL的辨别集构建
如前所述,MuRCL中的一个重要步骤是如何构建WSI辨别集 (WSI-Fset)。因此我们提出了一个新颖的基于强化学习的策略,记为RL-MIL,其基于WSI-Fbag构建WSI-Fset,且其中的一个WSI代理 R R R (循环神经网络) 通过强化学习来训练。如图2(b)所示,WSI-Fset的构建过程是一个序列决策。在每一步中,MIL聚合器 M ( ⋅ ) M(\cdot) M(⋅)和映射头 f ( ⋅ ) f(\cdot) f(⋅)使用WSI-Fset作为输入,获取相应的语义预测 p p p。与此同时, R R R利用输入的特征向量 v v v生成另一个WSI-Fset提案 s s s。
特别地,对于输入 x x x,RL-MIL迭代地生成一个序列 { x ~ 0 , … , x ~ t , … , } \{\tilde{x}_0,\dots,\tilde{x}_t,\dots,\} {x~0,…,x~t,…,}。在第 t t t次迭代时, M ( ⋅ ) M(\cdot) M(⋅)收到当前的 x ~ t \tilde{x}_t x~t,并输出特征向量 v t v_t vt,并通过 f ( ⋅ ) f(\cdot) f(⋅)获取幻灯片级嵌入 p t p_t pt。然后, p t p_t pt被传递给公式1。与此同时,WSI-Fset提案代理 R R R将 v t v_t vt作为输入来决定下一次 x ~ t + 1 \tilde{x}_{t+1} x~t+1的选择特征索引的行为 s t + 1 s_{t+1} st+1。随后, x ~ t + 1 \tilde{x}_{t+1} x~t+1从 x x x中选择,如图2( c \text{c} c)。如此,迭代一轮。
M ( ⋅ ) M(\cdot) M(⋅)使用ABMIL和CLAM, f ( ⋅ ) f(\cdot) f(⋅)和提案代理 R R R均使用循环神经网络,这样它们就可以通过在其中分别保持隐藏状态 h t − 1 R h^R_{t -1} ht−1R和 h t f h^f_t htf来探索所有先前输入的信息。注意本文的策略并不是一个RL框架,而是利用RL中的优化策略来优化自己的方法。在MuRCL中,RL用于辅助MIL辨别集构建。在集合构建阶段,代理多次扫描WSI以定位辨别特征,计算奖励并更新下次决策的代理。因此,本文借鉴的是强化学习的思想。
2.2.1 RL选择
在RL辨别集合构建的每一步中,特征索引从WSI-Fbag中选择WSI-Fset以供后续步骤。为了加速代理生成一个空间联系WSI-Fset提案,我们根据 x x x的聚类标签将其特征从排序,即具有相同聚类标签的特征将被分配相邻的索引。随后,对于每一个簇 C k C_k Ck中的特征,沿着它们相应区块的坐标重排列它们。重新排序后的WSI-Fbag被称为关联图,如图2( c \text{c} c)。然后,根据代理 R R R预测的动作 s s s从关联图中合成WSI-Fset,其中我们将动作 s s s制定为重新排列的集群的一组特征索引。特别地,第 t t t步的特征索引向量 s t ∈ R K ∗ 1 s_t\in\mathbb{R}^{K*1} st∈RK∗1是 R R R在上一步获取的,其元素 s t k s_t^k stk表示第 k k k个 重排列簇的特征索引。因此,对于簇 C k C_k Ck,从第 s t k s_t^k stk个特征开始采样一个长度为采样率乘以特征维度的序列,然后,拼接不同簇的序列得到 x ~ t \tilde{x}_t x~t。
2.2.2 奖励
WSI-Fset提案代理
R
R
R使用策略梯度法训练。在训练阶段,奖励函数用于控制
R
R
R的优化方向。本文利用WSI-Fset两个正对之间的相似性作为奖励来引导代理
R
R
R定位信息化特征。特别地,在第
t
t
t步,奖励函数为:
r
i
,
j
;
t
=
s
i
m
(
p
i
;
t
−
1
,
p
j
;
t
−
1
)
−
s
i
m
(
p
i
;
t
−
p
j
;
t
)
,
(2)
\tag{2} r_{i,j;t}=sim(p_{i;t-1},p_{j;t-1})-sim(p_{i;t}-p_{j;t}),
ri,j;t=sim(pi;t−1,pj;t−1)−sim(pi;t−pj;t),(2)其中
p
i
p_i
pi和
p
j
p_j
pj分别来自
x
~
i
\tilde{x}_i
x~i和
x
~
j
\tilde{x}_j
x~j。通过正WSI-Fset之间余弦距离,这将促使MIL模型通过最小化CL来关注于潜在的汇聚知识。
2.2.3 辨别集混合
为了在MIL聚合器训练时引入更多的扰动,本文使用了一种高效的特征汇聚策略,称为集合混合 (set-mixup),以增加WSI-Fset的多样性。对于训练批次中的WSI-Fset,混合
x
~
l
\tilde{x}_l
x~l到
x
~
q
\tilde{x}_{q}
x~q中生成增强表示
x
‾
q
\overline{x}_q
xq:
x
‾
q
=
λ
x
~
q
+
(
1
−
λ
)
x
~
l
,
(3)
\tag{3} \overline{x}_q=\lambda\tilde{x}_q+(1-\lambda)\tilde{x}_l,
xq=λx~q+(1−λ)x~l,(3)其中
λ
\lambda
λ是一个从
U
(
α
,
1.0
)
U(\alpha,1.0)
U(α,1.0)分布中的采样,本文设置
α
=
0.9
\alpha=0.9
α=0.9。这种混合用以增强语义概念学习。
2.3 RL-MIL的训练策略
如图2(a)所示,对比学习框架有两个分支,而两个分支中
M
(
⋅
)
M(\cdot)
M(⋅)、
f
(
⋅
)
f(\cdot)
f(⋅),以及
R
R
R的参数是共享的。在一个训练批次中,首先从WSI-Fbag中随机选择两个WSI-Fset作为初始化的正对。然后,正对将使用RL策略训练的代理重构。为了清晰表示,本文使用下表
(
⋅
)
i
(\cdot)_i
(⋅)i和
(
⋅
)
j
(\cdot)_j
(⋅)j表示两个分支,
{
(
⋅
)
t
}
t
=
0
T
\{(\cdot)_t\}_{t=0}^T
{(⋅)t}t=0T 表示每个分支生成的时间序列。这里
T
=
5
T=5
T=5,表示RNN在每个训练分支上运行五次。在开始时,正对表示为
x
~
i
;
0
\tilde{x}_{i;0}
x~i;0和
x
~
j
;
0
\tilde{x}_{j;0}
x~j;0,并通过
M
(
⋅
)
M(\cdot)
M(⋅)生成两个不同的特征向量
v
i
;
0
v_{i;0}
vi;0和
v
j
;
0
v_{j;0}
vj;0。与此同时,WSI级特征嵌入
p
i
;
0
p_{i;0}
pi;0和
p
j
;
0
p_{j;0}
pj;0被计算并用以计算CL损失
L
0
=
L
i
,
j
;
0
(
p
i
;
0
,
p
j
;
0
)
L_0=L_{i,j;0}(p_{i;0},p_{j;0})
L0=Li,j;0(pi;0,pj;0),其中来自不同WSI-Fbag的WSI-Fset被看作是负对。接下来执行第一次迭代,通过将
v
i
;
0
v_{i;0}
vi;0和
v
j
;
0
v_{j;0}
vj;0作为
R
R
R的初始状态来生成新的正对,对于新的正对,又开始同样的计算。在五次迭代中,
f
(
⋅
)
f(\cdot)
f(⋅)被同步处理,其两个分支的输出
{
p
i
;
t
,
p
j
;
0
}
t
=
10
5
\{p_{i;t},p_{j;0}\}_{t=10}^5
{pi;t,pj;0}t=105用于计算CL损失:
L
t
=
∑
t
=
0
5
L
i
,
j
(
p
i
;
t
,
p
j
;
t
)
L_t=\sum_{t=0}^5L_{i,j}(p_{i;t},p_{j;t})
Lt=∑t=05Li,j(pi;t,pj;t)。此外,
R
R
R需要最大化奖励
∑
t
=
1
5
γ
t
−
1
r
i
,
j
;
t
(
p
i
;
t
,
p
j
;
t
)
\sum_{t=1}^5\gamma^{t-1}r_{i,j;t}(p_{i;t},p_{j;t})
∑t=15γt−1ri,j;t(pi;t,pj;t),其中
γ
=
0.1
\gamma=0.1
γ=0.1;使用后了两种采样策略从
N
N
N个WSI-Fbag中生成了
2
N
2N
2N个WSI-Fset。该过程如算法1。
训练过程包含三个阶段:
- 随机采样WSI-Fset来训练 M ( ⋅ ) M(\cdot) M(⋅)和 f ( ⋅ ) f(\cdot) f(⋅),该阶段用于确保模型能够处理任意大小的序列;
- 固定 M ( ⋅ ) M(\cdot) M(⋅)和 f ( ⋅ ) f(\cdot) f(⋅),随机初始化 R R R并训练;
- 固定 R R R,微调 M ( ⋅ ) M(\cdot) M(⋅)和 f ( ⋅ ) f(\cdot) f(⋅)。
2.4 微调和推理
MIL对比学习框架可以深度探索幻灯片级WSI表示的不同区块之间的语义关系。对于最终的幻灯片级预测,将使用标记WSI来微调框架。在这个阶段, f ( ⋅ ) f(\cdot) f(⋅)的输出维度由128变为类别数。微调依然包含三个阶段,不同的是, f ( ⋅ ) f(\cdot) f(⋅)会后接softmax以获得置信度得分,然后该得分的增量作为 R R R的奖励,即 r ^ t = p ^ t − p ^ t − 1 \hat{r}_t=\hat{p}_t-\hat{p}_{t-1} r^t=p^t−p^t−1,其中 ( ⋅ ^ ) (\hat{\cdot}) (⋅^)表示微调中相应的变量, p ^ \hat{p} p^是softmax预测概率。
MuRCL的测试过程与微调过程一致:
- 给定测试WSI-Fbag,随机采样WSI-Fset, M ( ⋅ ) M(\cdot) M(⋅)提供初始状态;
- R R R确定WSI-Fset, M ( ⋅ ) M(\cdot) M(⋅)和 f ( ⋅ ) f(\cdot) f(⋅)处理。在这个阶段,代理迭代生成状态向量,代理的最后一次输出作为WSI-Fset提案,然后输出分类预测;
对比损失可以拉近类别相近特征的距离以及加大类别不同特征的距离。
3 实验
3.1 数据集
- Camelyon16:乳腺癌检测数据,包含270个训练WSI和129个测试WSI。任务目的为是否癌症,或者定位。预处理后包含2.7百万个区块,平均每个包6881;
- TCGA-Lung:包含TCGA-LUSC和TCGA-LUAD两个子项目,总计1041个诊断图片,529LUAD和512LUSC,用于WSI子类型分类和存活分析。预处理后,平均每个包11540;
- TCGA-Kidney:包含TCGA-KICH、TCGA-KIRC,以及TCGA-KIRP三个子项目,累计734张WSI,其中KICH92、KIRC411,以及KIRP231,这也适用于多分类。
- TCGA-Esca:包含两个子类,总计156WSI,其中90个鳞状细胞癌和66个腺癌。
预处理:
所有实验使用
20
×
20\times
20×量级的WSI,每一个WSI被裁剪为
256
×
256
256\times256
256×256的多个区块,组织区域低于
35
%
35\%
35%的区块将被抛弃。对于Camelyon16,训练集
20
%
20\%
20%作为验证集。对于TCGA,训练:验证:测试=3:1:1。
3.2 实现细节和评估指标
- 每个区块通过预训练模型嵌入为512维的特征向量,其中Camelyon16和TCGA-Lung使用SimCLR和ResNet18,余下只使用ResNet18;
- WSI特征包首先使用Kmeans聚类为10簇,每个簇 C k C_k Ck的采样率设置为 1024 / u 1024/u 1024/u,其中 u u u是包特征的数量;
- 批次大小 N = 128 N=128 N=128,温度参数 τ = 1 \tau=1 τ=1;
- 训练的第一阶段使用Adam,其中 M ( ⋅ ) M(\cdot) M(⋅)的学习率设置为1e-4, f ( ⋅ ) f(\cdot) f(⋅)的为1e-5,权重衰减设置为1e-5;
- 第二阶段代理 R R R的初始学习率设置为1e-5;
- 第三阶段 M ( ⋅ ) M(\cdot) M(⋅)和 f ( ⋅ ) f(\cdot) f(⋅)使用Adam联合优化,学习率分别设置为5e-5和1e-5,权重衰减不变;
- 三个阶段的训练批次分别为100、30,以及100;
- 评价指标使用ACC、AUC,以及F1。