论文阅读 (83):MuRCL: Multi-instance Reinforcement Contrastive Learning for Whole Slide Image (医学图像)

1 概述

1.1 题目

2022用于全幻灯片分类的多示例强化对比学习 (MuRCL: Multi-instance reinforcement contrastive learning for whole slide image classification)

1.2 摘要

多示例学习 (MIL) 广泛应用于自动全幻灯片图像 (Whole slide image, WSI) 分析,其处理策略可以分为:

  1. 实例特征提取;
  2. 特征聚合。

然而,由于幻灯片级别标签的弱监督性,MIL模型的训练过程通常会呈现严重的过拟合。在这种情况下,从有限的幻灯片级别标注的数据中发掘更多的信息是至关重要的。

与已有的方法不同,本文着重于探索不同实例 (区块) 之间的潜在关系,而非提升实例特征的提取,以提高模型的泛化能力。具体地,MuRCL从以下几个角度处理问题:

  1. 训练自监督管理器,然后基于WSI幻灯片级别标签微调。这个过程被称为对比学习 (Contrastive learning, CL),其基于WSI中相同的区块级特征包构建了正/负判别特征集
  2. 为了加速CL训练,设计了一个基于强化学习的代理,根据幻灯片级特征聚合的在线奖励 (Online reward) 以逐步更新辨别特征集的选择。然后使用标记的WSI数据来更新模型和习得特征,并获取最终的WSI分类。

实验在三个公开的WSI分类数据集上进行,包括Camelyon16、TCGA-Lung,以及TCGA-Kidney。实验结果验证了MuRCL的性能,其在TCGA-Lung数据集的效果尤为明显。

图1展示了MuRCL与一般的MIL的区别。

图1:MuRCL与一般MIL方法的对比:(a) 一般方法基于图像级标签,实现区块提取、区块选择,以及区块聚合;(b) MuRCL利用不同补丁的内在关系,通过最大化同一WSI的两个判别特征集之间的一致性来训练,然后是微调WSI的预测

1.3 代码

Torch:https://github.com/wwu98934/MuRCL

1.4 引用

@article{Zhu:2022:113,
author		=	{Zhong Hang Zhu and Le Quan Yu and Wei Wu and Rong Shan Yu and De Fu Zhang and Lian Sheng Wang},
title		=	{{MuRCL}: {M}ulti-instance reinforcement contrastive learning for whole slide image classification},
journal		=	{{IEEE} Transactions on Medical Imaging},
pages		=	{1--13}
year		=	{2022},
doi			=	{10.1109/TMI.2022.3227066}
}

2 方法


  图2:MuRCL的自监督学习过程:(a) MuRCL:给定一个WSI的输入特征包 (WSI-Fbag) x x x,两个RL-MIL分支的输出是用于对比损失的正对;(b) RL-MIL:一个奖励导向的代理 R \boldsymbol{R} R被用于从 x x x选择辨别特征集 (WSI-Fset) x ~ \tilde{x} x~ x ~ \tilde{x} x~首先被随机初始化,然后通过 R \boldsymbol{R} R更新。 x ~ \tilde{x} x~传递给MIL聚合器,以生成特征嵌入 v v v,并通过映射头 f ( ⋅ ) f(\cdot) f()输出 p p p,其将用于计算对比损失;( c \text{c} c) RL选择 s t s_t st用于从由 s t k s_t^k stk构成的关联图的每个簇中选出索引特征

图2(a)展示了MuRCL的总体框架,其通过代理构建用于对比学习的正/负对,并从输入特征包中选择两个独立的辨别集。然后,MIL汇聚器 M ( ⋅ ) M(\cdot) M()和投影头 f ( ⋅ ) f(\cdot) f()使用对比损失来最大化正辨别集之间的一致性。MuRCL的每个分支是一个RL-MIL。

图2(b)展示了RL-MIL的序列决策过程。给定输入包,RL-MIL迭代生成一个辨别集的序列,并使用MIL聚合器和投影头输出特征向量序列。特别地,在每一步,代理 (辨别集网络) 确定一组特征索引。然后,通过组合每个包簇的索引特征来构建下一步的判别集,如图2( c \text{c} c)。最后利用WSI标签微调自监督预训练模型,以获取用于预测的表示。

2.1 多示例对比学习

多示例对比学习使用WSI-Fbag作为输入,其中每个包包含了在ImageNet上训练的ResNet18处理后的区块级嵌入CL的一个关键步骤是构建用于训练的逻辑正/负对 (即语义相似/不相似实例)。与已有的基于图像增强的策略不同,我们从每个WSI-Fbag中采样不同的WSI辨别集合 (更小的WSI-Fset),并构建用于CL训练的基于集合的正/负对。

每个WSI-Fset是从WSI-Fbag的多个特征簇得到的子集的组合。特别地,给定WSI0Fbag x x x,首先使用聚类算法,例如kMeans将其划分为 K K K个簇 C k ( k ∈ [ 1 , 2 , … , K ] ) C_k(k\in[1,2,\dots,K]) Ck(k[1,2,,K])。通过从 C k C_k Ck采样到一个子集,WSI-Fset x ~ \tilde{x} x~是所有这些子集的拼接。每一个簇的采样率保持一致,因此 x ~ \tilde{x} x~有常量个实例嵌入。

随后, x ~ \tilde{x} x~被传递给MIL聚合器 M ( ⋅ ) M(\cdot) M()以及投影头 f ( ⋅ ) f(\cdot) f(),以获取WSI级别特征嵌入 p p p。值得注意的是这里有多种采样策略,因此能够获得不同的 x ~ \tilde{x} x~以及相应的嵌入。不同的 x ~ \tilde{x} x~可以看作是相同WSI的不同视角,可被用于构建CL中的正对。

{ x ~ n } n = 1 N \{\tilde{x}_n\}_{n=1}^N {x~n}n=1N表示 N N N个WSI-Fset的,其中 x ~ i \tilde{x}_i x~i x ~ j \tilde{x}_j x~j采样自同一个 x x x,其他的采样自不同的WSI-Fbag。然后,CL损失计算为:
L i , j = − log ⁡ exp ⁡ ( s i m ( p i , p j ) / τ ) ∑ n = 1 N 1 ( n ≠ i ) exp ⁡ ( s i m ( p i , p n ) / τ ) , (1) \tag{1} L_{i,j}=-\log\frac{\exp(sim(p_i,p_j)/\tau)}{\sum_{n=1}^N1(n\neq i)\exp(sim(p_i,p_n)/\tau)}, Li,j=logn=1N1(n=i)exp(sim(pi,pn)/τ)exp(sim(pi,pj)/τ),(1)其中 τ \tau τ是温度参数、 s i m ( ⋅ , ⋅ ) sim(\cdot,\cdot) sim(,)表示两个向量之间的余弦相似度、 1 ( n ≠ i ) ∈ { 0 , 1 } 1(n\neq i)\in\{0,1\} 1(n=i){0,1}是一个指示函数,即如果条件满足则为 1 1 1,反之为 0 0 0。本文使用NT-Xent作为目标函数,以最大化正对之间的相似度和最小化负对之间的相似度。在这种情况下,MIL聚合器将能够习得用于准确分类的聚合知识。

2.2 基于RL的辨别集构建

如前所述,MuRCL中的一个重要步骤是如何构建WSI辨别集 (WSI-Fset)。因此我们提出了一个新颖的基于强化学习的策略,记为RL-MIL,其基于WSI-Fbag构建WSI-Fset,且其中的一个WSI代理 R R R (循环神经网络) 通过强化学习来训练。如图2(b)所示,WSI-Fset的构建过程是一个序列决策。在每一步中,MIL聚合器 M ( ⋅ ) M(\cdot) M()映射头 f ( ⋅ ) f(\cdot) f()使用WSI-Fset作为输入,获取相应的语义预测 p p p。与此同时, R R R利用输入的特征向量 v v v生成另一个WSI-Fset提案 s s s

特别地,对于输入 x x xRL-MIL迭代地生成一个序列 { x ~ 0 , … , x ~ t , … , } \{\tilde{x}_0,\dots,\tilde{x}_t,\dots,\} {x~0,,x~t,,}。在第 t t t次迭代时, M ( ⋅ ) M(\cdot) M()收到当前的 x ~ t \tilde{x}_t x~t,并输出特征向量 v t v_t vt,并通过 f ( ⋅ ) f(\cdot) f()获取幻灯片级嵌入 p t p_t pt。然后, p t p_t pt被传递给公式1。与此同时,WSI-Fset提案代理 R R R v t v_t vt作为输入来决定下一次 x ~ t + 1 \tilde{x}_{t+1} x~t+1的选择特征索引的行为 s t + 1 s_{t+1} st+1。随后, x ~ t + 1 \tilde{x}_{t+1} x~t+1 x x x中选择,如图2( c \text{c} c)。如此,迭代一轮

M ( ⋅ ) M(\cdot) M()使用ABMIL和CLAM, f ( ⋅ ) f(\cdot) f()和提案代理 R R R均使用循环神经网络,这样它们就可以通过在其中分别保持隐藏状态 h t − 1 R h^R_{t -1} ht1R h t f h^f_t htf来探索所有先前输入的信息。注意本文的策略并不是一个RL框架,而是利用RL中的优化策略来优化自己的方法。在MuRCL中,RL用于辅助MIL辨别集构建。在集合构建阶段,代理多次扫描WSI以定位辨别特征,计算奖励并更新下次决策的代理。因此,本文借鉴的是强化学习的思想。

2.2.1 RL选择

在RL辨别集合构建的每一步中,特征索引从WSI-Fbag中选择WSI-Fset以供后续步骤。为了加速代理生成一个空间联系WSI-Fset提案,我们根据 x x x的聚类标签将其特征从排序,即具有相同聚类标签的特征将被分配相邻的索引。随后,对于每一个簇 C k C_k Ck中的特征,沿着它们相应区块的坐标重排列它们。重新排序后的WSI-Fbag被称为关联图,如图2( c \text{c} c)。然后,根据代理 R R R预测的动作 s s s从关联图中合成WSI-Fset,其中我们将动作 s s s制定为重新排列的集群的一组特征索引。特别地,第 t t t步的特征索引向量 s t ∈ R K ∗ 1 s_t\in\mathbb{R}^{K*1} stRK1 R R R在上一步获取的,其元素 s t k s_t^k stk表示第 k k k个 重排列簇的特征索引。因此,对于簇 C k C_k Ck,从第 s t k s_t^k stk个特征开始采样一个长度为采样率乘以特征维度的序列,然后,拼接不同簇的序列得到 x ~ t \tilde{x}_t x~t

2.2.2 奖励

WSI-Fset提案代理 R R R使用策略梯度法训练。在训练阶段,奖励函数用于控制 R R R的优化方向。本文利用WSI-Fset两个正对之间的相似性作为奖励来引导代理 R R R定位信息化特征。特别地,在第 t t t步,奖励函数为:
r i , j ; t = s i m ( p i ; t − 1 , p j ; t − 1 ) − s i m ( p i ; t − p j ; t ) , (2) \tag{2} r_{i,j;t}=sim(p_{i;t-1},p_{j;t-1})-sim(p_{i;t}-p_{j;t}), ri,j;t=sim(pi;t1,pj;t1)sim(pi;tpj;t),(2)其中 p i p_i pi p j p_j pj分别来自 x ~ i \tilde{x}_i x~i x ~ j \tilde{x}_j x~j。通过正WSI-Fset之间余弦距离,这将促使MIL模型通过最小化CL来关注于潜在的汇聚知识。

2.2.3 辨别集混合

为了在MIL聚合器训练时引入更多的扰动,本文使用了一种高效的特征汇聚策略,称为集合混合 (set-mixup),以增加WSI-Fset的多样性。对于训练批次中的WSI-Fset,混合 x ~ l \tilde{x}_l x~l x ~ q \tilde{x}_{q} x~q中生成增强表示 x ‾ q \overline{x}_q xq
x ‾ q = λ x ~ q + ( 1 − λ ) x ~ l , (3) \tag{3} \overline{x}_q=\lambda\tilde{x}_q+(1-\lambda)\tilde{x}_l, xq=λx~q+(1λ)x~l,(3)其中 λ \lambda λ是一个从 U ( α , 1.0 ) U(\alpha,1.0) U(α,1.0)分布中的采样,本文设置 α = 0.9 \alpha=0.9 α=0.9。这种混合用以增强语义概念学习。

2.3 RL-MIL的训练策略

图2(a)所示,对比学习框架有两个分支,而两个分支中 M ( ⋅ ) M(\cdot) M() f ( ⋅ ) f(\cdot) f(),以及 R R R的参数是共享的。在一个训练批次中,首先从WSI-Fbag中随机选择两个WSI-Fset作为初始化的正对。然后,正对将使用RL策略训练的代理重构。为了清晰表示,本文使用下表 ( ⋅ ) i (\cdot)_i ()i ( ⋅ ) j (\cdot)_j ()j表示两个分支 { ( ⋅ ) t } t = 0 T \{(\cdot)_t\}_{t=0}^T {()t}t=0T 表示每个分支生成的时间序列。这里 T = 5 T=5 T=5,表示RNN在每个训练分支上运行五次。在开始时,正对表示为 x ~ i ; 0 \tilde{x}_{i;0} x~i;0 x ~ j ; 0 \tilde{x}_{j;0} x~j;0,并通过 M ( ⋅ ) M(\cdot) M()生成两个不同的特征向量 v i ; 0 v_{i;0} vi;0 v j ; 0 v_{j;0} vj;0。与此同时,WSI级特征嵌入 p i ; 0 p_{i;0} pi;0 p j ; 0 p_{j;0} pj;0被计算并用以计算CL损失 L 0 = L i , j ; 0 ( p i ; 0 , p j ; 0 ) L_0=L_{i,j;0}(p_{i;0},p_{j;0}) L0=Li,j;0(pi;0,pj;0),其中来自不同WSI-Fbag的WSI-Fset被看作是负对。接下来执行第一次迭代,通过将 v i ; 0 v_{i;0} vi;0 v j ; 0 v_{j;0} vj;0作为 R R R的初始状态来生成新的正对,对于新的正对,又开始同样的计算。在五次迭代中, f ( ⋅ ) f(\cdot) f()被同步处理,其两个分支的输出 { p i ; t , p j ; 0 } t = 10 5 \{p_{i;t},p_{j;0}\}_{t=10}^5 {pi;t,pj;0}t=105用于计算CL损失: L t = ∑ t = 0 5 L i , j ( p i ; t , p j ; t ) L_t=\sum_{t=0}^5L_{i,j}(p_{i;t},p_{j;t}) Lt=t=05Li,j(pi;t,pj;t)。此外, R R R需要最大化奖励 ∑ t = 1 5 γ t − 1 r i , j ; t ( p i ; t , p j ; t ) \sum_{t=1}^5\gamma^{t-1}r_{i,j;t}(p_{i;t},p_{j;t}) t=15γt1ri,j;t(pi;t,pj;t),其中 γ = 0.1 \gamma=0.1 γ=0.1;使用后了两种采样策略从 N N N个WSI-Fbag中生成了 2 N 2N 2N个WSI-Fset。该过程如算法1

训练过程包含三个阶段:

  1. 随机采样WSI-Fset来训练 M ( ⋅ ) M(\cdot) M() f ( ⋅ ) f(\cdot) f(),该阶段用于确保模型能够处理任意大小的序列;
  2. 固定 M ( ⋅ ) M(\cdot) M() f ( ⋅ ) f(\cdot) f(),随机初始化 R R R并训练;
  3. 固定 R R R,微调 M ( ⋅ ) M(\cdot) M() f ( ⋅ ) f(\cdot) f()

2.4 微调和推理

MIL对比学习框架可以深度探索幻灯片级WSI表示的不同区块之间的语义关系。对于最终的幻灯片级预测,将使用标记WSI来微调框架。在这个阶段, f ( ⋅ ) f(\cdot) f()的输出维度由128变为类别数。微调依然包含三个阶段,不同的是, f ( ⋅ ) f(\cdot) f()会后接softmax以获得置信度得分,然后该得分的增量作为 R R R的奖励,即 r ^ t = p ^ t − p ^ t − 1 \hat{r}_t=\hat{p}_t-\hat{p}_{t-1} r^t=p^tp^t1,其中 ( ⋅ ^ ) (\hat{\cdot}) (^)表示微调中相应的变量, p ^ \hat{p} p^是softmax预测概率。

MuRCL的测试过程与微调过程一致:

  1. 给定测试WSI-Fbag,随机采样WSI-Fset, M ( ⋅ ) M(\cdot) M()提供初始状态;
  2. R R R确定WSI-Fset, M ( ⋅ ) M(\cdot) M() f ( ⋅ ) f(\cdot) f()处理。在这个阶段,代理迭代生成状态向量,代理的最后一次输出作为WSI-Fset提案,然后输出分类预测;

对比损失可以拉近类别相近特征的距离以及加大类别不同特征的距离。

3 实验

3.1 数据集

  1. Camelyon16:乳腺癌检测数据,包含270个训练WSI和129个测试WSI。任务目的为是否癌症,或者定位。预处理后包含2.7百万个区块,平均每个包6881;
  2. TCGA-Lung:包含TCGA-LUSCTCGA-LUAD两个子项目,总计1041个诊断图片,529LUAD和512LUSC,用于WSI子类型分类和存活分析。预处理后,平均每个包11540;
  3. TCGA-Kidney:包含TCGA-KICHTCGA-KIRC,以及TCGA-KIRP三个子项目,累计734张WSI,其中KICH92、KIRC411,以及KIRP231,这也适用于多分类。
  4. TCGA-Esca:包含两个子类,总计156WSI,其中90个鳞状细胞癌和66个腺癌。

预处理
所有实验使用 20 × 20\times 20×量级的WSI,每一个WSI被裁剪为 256 × 256 256\times256 256×256的多个区块,组织区域低于 35 % 35\% 35%的区块将被抛弃。对于Camelyon16,训练集 20 % 20\% 20%作为验证集。对于TCGA,训练:验证:测试=3:1:1。

3.2 实现细节和评估指标

  1. 每个区块通过预训练模型嵌入为512维的特征向量,其中Camelyon16和TCGA-Lung使用SimCLR和ResNet18,余下只使用ResNet18;
  2. WSI特征包首先使用Kmeans聚类为10簇,每个簇 C k C_k Ck的采样率设置为 1024 / u 1024/u 1024/u,其中 u u u是包特征的数量;
  3. 批次大小 N = 128 N=128 N=128,温度参数 τ = 1 \tau=1 τ=1
  4. 训练的第一阶段使用Adam,其中 M ( ⋅ ) M(\cdot) M()的学习率设置为1e-4, f ( ⋅ ) f(\cdot) f()的为1e-5,权重衰减设置为1e-5;
  5. 第二阶段代理 R R R的初始学习率设置为1e-5;
  6. 第三阶段 M ( ⋅ ) M(\cdot) M() f ( ⋅ ) f(\cdot) f()使用Adam联合优化,学习率分别设置为5e-5和1e-5,权重衰减不变;
  7. 三个阶段的训练批次分别为100、30,以及100;
  8. 评价指标使用ACC、AUC,以及F1。
多智能体深度强化学习,用于群组分发中的任务卸载。 多智能体深度强化学习是一种强化学习的方法,可以应用于群组分发中的任务卸载问题。在群组分发中,有多个智能体,每个智能体都拥有一定的处理能力和任务需求。任务卸载是指将任务从一个智能体卸载到其他智能体上进行处理,以实现任务优化和系统性能的提升。 多智能体深度强化学习通过使用深度神经网络来构建智能体的决策模型,并基于强化学习框架进行智能体的训练和决策制定。在任务卸载中,每个智能体的状态可以由其当前的任务负载、处理能力和通信延迟等因素来表示。智能体的动作则是选择是否将任务卸载到其他智能体上进行处理。通过与环境交互,智能体可以通过强化学习来调整其决策策略,以优化任务卸载过程中的系统性能。 在多智能体深度强化学习中,可以使用任务奖励来指导智能体的行为。例如,当一个智能体选择将任务卸载给处理能力更高的智能体时,可以给予奖励以鼓励这种行为。同时,如果任务卸载导致较高的通信延迟或任务负载不均衡等问题,可以给予惩罚以避免这些不良的决策。 通过多智能体深度强化学习,可以实现群组分发中的任务卸载优化。智能体可以通过学习和适应来提高系统的整体性能和效率,从而实现任务分配的最优化。这种方法可以应用于各种领域,例如云计算、物联网和机器人协作等多智能体系统。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值