MamMIL: Multiple Instance Learning for Whole Slide Images with State Space Models论文总结

MamMIL: Multiple Instance Learning for Whole Slide Images with State Space Models

MamMIL:利用状态空间模型对全切片图像进行多示例学习

研究团队:清华大学、哈尔滨工业大学、深圳大学

一、研究背景

  • 最近,癌症诊断的金标准——病理诊断,通过将Transformer与使用整张全切片图像( WSIs )多示例学习( MIL )框架相结合,取得了优越的性能。

    • WSIs的特点:WSIs由数百亿像素组成,阻碍了精细注释的获取,通常只能获得WSI级别的标签;另一方面,由于图形处理单元(GPU)的内存限制,将大规模的WSI输入到深度学习模型中往往是不可行的。

    • 多示例学习:MIL将每个WSI看成是由从WSI中分割出来的小块组成的,称为实例。在MIL中,特征学习不再在千兆像素的WSI上进行。相反,在每个小实例上进行特征提取,从而解决了GPU内存限制的问题。在实例特征提取后,MIL将所有实例特征聚合得到一个包特征,该特征可以被WSI级别的标签监督,从而使模型训练过程得以进行。

    • 由于示例聚合直接影响包特征的判别性,从而决定了模型的性能,因此许多MIL框架关注于聚合过程

  • 然而,WSIs的千兆像素特性为Transformer中的二次复杂度自注意力机制在MIL中的应用带来了巨大的挑战。现有研究通常采用线性注意力来提高计算效率,但不可避免地带来性能瓶颈

    • 组织间的相互作用在肿瘤进展中至关重要。Shao等人提出了TransMIL,它利用具有自注意力的Transformer来建模成对实例依赖。然而,直接将成千上万个实例的千兆像素WSI输入到Transformer中,由于自注意力的二次复杂性,通常会导致内存溢出问题,从而阻碍模型训练过程。

  • 本文将选择性结构化状态空间模型(即Mamba)与MIL相结合,提出了一种用于WSI分类的MamMIL框架,在保持线性复杂度的同时,实现了实例依赖关系的建模。

    • 问题:将Mamba应用于MIL进行WSI分类存在一些挑战。首先,针对一维( 1D )序列构建Mamba。当将2D WSIs展平为1D序列作为输入时,不可避免地会出现2D空间信息的丢失。其次,Mamba使用"扫描"策略以单向方式计算隐状态。尽管单向扫描对于具有时间序列特征的序列建模是可行的,但对于具有双向成对依赖关系的WSIs可能是低效的。

    • 做法:引入了双向状态空间模型2D上下文感知块,使MamMIL能够学习具有2D空间关系的双向实例依赖关系。

二、本文贡献

  • Mamba作为一个与Transformer性能相当的线性复杂度模型,本文首次将其应用到MIL框架中,用于WSI分类任务,其中每个包由数万个实例组成。

  • 引入了一个双向SSM块来解决Mamba中的单向建模问题。此外,使用2D上下文感知块来避免1D序列中2D空间信息的丢失。

  • 在两个数据集上的实验证明,MamMIL优于SOTA方法,同时获得了比基于Transformer的MIL框架更低的GPU内存占用。本文工作为未来的MIL研究提供了新的架构和方向。

三、研究内容

1.框架概述

图1展示了MamMIL的概述,它主要由实例特征提取阶段和实例特征聚合阶段组成。

  • 在第一阶段,使用滑动窗口将WSI拆分为小的、不重叠的图块作为实例。然后,采用预训练的ResNet50网络来提取实例的特征。经过ReLU激活的可训练线性投影后,实例特征构成一个1D序列,并在最后位置添加一个类标记。然后,将该序列送入第二阶段。

  • 第二阶段由一系列堆叠的MIL - SSM模块组成。每个MIL - SSM模块由双向SSM ( Bi-SSM )块和2D上下文感知块( 2D-CAB )。最后,使用类标记作为WSI分类的包特征

2.实例特征提取

3.实例特征融合

3.1Bi-SSM Block

  • 在MamMIL中,Bi - SSM块是一个关键组件,它利用Mamba以线性复杂度快速挖掘大量实例之间的判别依赖关系。

    • Mamba的主要目的是学习一个从输入实例序列$X = \{ x_i \} ^M_{i = 1}$ 到输出序列$\{ y_i\} ^M _{i = 1}$ 的映射,通过隐状态$\{ h_i \}^M_{i = 1}$ 来实现:

    • 在具有离散输入和权重的深度学习模型中,需要对式1进行离散化处理。通常,与Mamba中一样,$\overline{A}、\overline{B}、\overline{C}$ 采用零阶保持规则离散,时间步长∆为:

      其中A\B\C和∆为可学习的参数

    • 为了增强上下文感知能力,Mamba基于3个可学习的线性投影$l_B、l_C和l_∆$ 将参数B、C和∆与输入序列X相关联:

      式中:P∆为可学习参数。

  • 然而,从公式1中可以看到隐状态hi只与之前的隐状态和当前的输入有关,使得hi以单向"扫描"的方式计算。然而,在WSIs中,任何实例之间都可能存在依赖关系。为了解决这个问题,受ViM的启发,构造了两个SSM来同时模拟前向和后向序列方向,并构造了Bi - SSM块。

    • 对于前向SSM,直接将X馈入SSM,得到输出Y。对于后向SSM,首先对实例特征序列进行求逆,同时将类标记固定在最后一个位置来构造X′。然后,将X′送入另一个SSM得到输出Y′。最后,还原Y′的实例特征部分,得到反向SSM的输出。

  • 为了融合两个SSM的输出,采用了门控机制。

    • 对于第l个MIL - SSM模块的一个输入序列$X^{ ( l )}$ ,首先对$X^{ ( l )}$ 进行洗牌,同时固定类标记的位置,以缓解过拟合。

    • 然后,将洗牌后的$X^{ ( l )}$ 送入两个线性投影,得到输出$\overline{X}和\overline{Z}$ 。将$\overline{X}$ 反相得到$\overline{X'}$ 后,$\overline{X}$和$\overline{X'}$分别送入一维卷积,得到X和X′,作为前向或后向SSM的输入。

    • 同时,$\overline{Z}$被SiLU函数激活,并通过逐元素乘法来门控两个SSM输出的平均值。

    • 最后,将门控序列重新洗牌后作为Bi - SSM的输出。

3.2 2D Context-Aware Block

  • 由于Bi - SSM是在1D序列上进行的,因此实例中的2D空间关系仍然无法被感知。

  • 本文引入了基于金字塔结构卷积的2D - CAB。

    • 首先去除类标记,并将剩余的1D特征序列重塑为2D平方特征图。

    • 如果实例数不可分割,首先将序列循环到最近的完美正方形。

    • 然后,使用3 × 3、5 × 5和7 × 7深度卷积提取2D空间关系。将卷积结果与残差连接求和得到输出X′′。

    • 最后,X′′被压扁回到1D序列。

    • 通过移除填充和复制回类令牌(class token),可以获得包含2D空间信息的1D实例序列作为2D - CAB的输出。

3.3 WSI Classification and Loss Function

  • 在使用L MIL - SSM模块进行区分性实例依赖挖掘和实例特征聚合后,使用类标记作为WSI分类的袋特征,并使用softmax函数激活线性投影。在实验中,本文使用交叉熵损失对模型进行优化。

四、实验

1.数据集

本文使用了两个公开的WSI数据集Camelyon16和BRACS。

  • Camelyon16是一个用于乳腺癌微转移检测的数据集,正式分为270个训练和129个测试WSI。本实验进一步将训练WSI以2:1的比例划分为训练集和验证集,官方测试WSI不进行测试。

  • BRCAS是一个乳腺癌亚型分类数据集,其中每个WSI分为良性、非典型和恶性肿瘤。BRCAS正式划分为395个WSIs的训练集、65个WSIs的验证集和87个WSIs的测试集。采用官方划分进行评价。

2.实验设置

3.对比结果

分类性能比较:MamMIL与其他SOTA方法的分类性能比较结果如表1所示。本文提出的MamMIL方法除了在BRCAS上的精度略低于ABMIL外,在所有指标上都超过了所有SOTA方法。此外,我们可以看到,MamMIL方法在AUC上优于SOTA方法,在BRCAS数据集上比Camelyon16数据集有更大的优势。这可能是由于与Camelyon16相比,BRCAS数据集中有更多的训练WSI。因此,针对大规模数据提出并优化的MamMIL中的SSM块被过拟合的可能性较小。这一现象也暗示了如果在大规模数据集上进行训练,MamMIL可能会表现出更好的性能。

GPU内存足迹比较:除了菲尤等提出的方法,现有的MIL框架要么假设WSIs中的实例遵循i.i.d .假设,使用局部注意力进行实例聚合(例如, ABMIL和CLAM),要么使用全局自注意力(例如, TransMIL)对实例依赖关系进行建模。前一种方法需要少量的GPU内存,但其性能通常有限。后者通常表现更好,但需要更大的GPU内存。相比之下,MamMIL利用Mamba对线性复杂度的长序列建模的能力解决了这一矛盾。如图2所示,对比TransMIL,GPU的内存MamMIL在Camelyon16和BRCAS上的占有率分别降低了65.5 %和71.7 %。此外,MamMIL的内存占用也低于基于S4D的方法,因为基于S4D的方法使用了复值参数,而Mamba只需要实值权重。总之,Mamba的GPU内存占用与i.i.d.假设的框架相当,差异小于1.6 GB。

4.消融实验

提出组件的影响:为了验证提出的组件的有效性,本文进行了消融研究,结果如表2所示。可以看到,所有提出的组件都可以提高模型性能。其中,shuffle操作对MamMIL的表现影响最大。这可能是因为shuffle操作可以防止模型在有限的训练数据中记忆固定的模式,从而避免过拟合。此外,同时使用前向和后向SSM可以提高性能,因为它解决了SSM只能对序列进行单向建模的问题。2D - CAB也可以提高模型性能,证明了使MamMIL能够感知实例的2D空间关系的有效性。

class token位置的影响:考虑到SSM中的隐状态是对先前输入实例的压缩,为了将所有实例信息融合到类令牌中,本文将其添加在特征序列的最后位置。为验证其有效性,进行了消融研究。如表2所示,如果将类标记放在第一位,由于SSMs从头到尾依次扫描和聚合特征,模型几乎无法训练。在中间添加类令牌或不使用类令牌(利用平均池化得到词袋特征)对两个数据集的性能相当,但两者在AUC上的表现均弱于本文提出的方法。综上,将类标记放在最后一个位置可以取得更好的整体效果

  • 13
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值