论文119:Sparse multi-modal graph transformer with shared-context processing (2023, CVPR)

1 要点

题目:用于千兆像素图像表示学习的共享上下文处理稀疏多模态图变换器 (Sparse multi-modal graph transformer with shared-context processing for representation learning of giga-pixel images)

代码:https://github.com/raminnakhli/amigo

研究目的:
传统的多示例学习 (MIL) 方法在处理WSI时忽略了单个细胞内的显式信息,并且依赖于大量数据,容易过拟合。此外,MIL方法主要关注图像块,限制了模型对单个细胞的分辨率,并且缺乏细胞间相互作用的信息。因此,提出了一种利用组织内细胞图层次结构来提供单个WSI表示的新方法,同时能够动态地在细胞级和组织级信息之间聚焦。

关键技术

  1. 共享上下文处理
    在不同模态之间共享处理步骤,以利用不同染色的WSI之间的共享上下文信息;
  2. 多模态图神经网络
    使用GNN处理从WSI生成的细胞图,关注细胞级别的信息及其交互;
  3. 稀疏处理
    在输入图的特征和邻接矩阵上应用掩码操作来减少计算复杂性;
  4. 批量截尾部分(BCP)技术
    为了解决生存预测中截尾数据 (censored data) 导致的训练问题,提出了一种新的损失函数,通过在批次中平衡截尾和非截尾数据的比例来改善梯度信号;

数据集

  1. InUIT:包含1,600个组织微阵列 (TMA) 核心,来自188名患者的高级别浆液性卵巢癌队列,使用Ki67、CD8和CD20生物标记物染色,不同的染色视作一个模态;
  2. MIBC:包含585个TMA核心,来自58名患者的肌浸润性膀胱癌队列,使用Ki67、CK20和P16生物标记物染色。

注意:本文不同于已有的WSI分类,其主要是做WSI图像的生存预测,且不同的模态是因为每个患者的细胞可以在不同的染色下观测。因此,对于WSI分类或者多模态的小伙伴,可以只关注其单模态学习和模态融合,并获取WSI表示的部分。

2 方法

2.2 问题声明

{ x n , i m ∣ n = 0 , … , N ; m = 0 , … , M ; i = 0 , … , C ( n , m ) } \{x_{n,i}^m | n = 0, \dots, N; m = 0, \dots, M; i = 0, \dots, C(n, m)\} {xn,imn=0,,N;m=0,,M;i=0,,C(n,m)}表示作为数据集中的WSI集合,其中 n n n是患者编号、 m m m是模态编号,以及 i i i是图像标识符。在这种设置下, N N N表示患者总数、 M M M是模态总数,以及 C ( n , m ) C(n, m) C(n,m)表示患者 n n n在模态 m m m下可用图像的数量。我们的目标是预测每个患者的预估存活时间,也称为结果。更具体地说,使用患者所有可用的图像 (跨不同模态) 来获得基于此的统一表示 R n ∈ R 1 × d R_n \in \mathbb{R}^{1 \times d} RnR1×d,以便预测存活时间。为了避免重复,假设 x n , i m x_{n,i}^m xn,im指的是图像以及从中生成的细胞图。图2展示了所提出方法的总览图。

2.2 多模态共享上下文处理

处理多模态数据的常见策略是使用单独的编码器对每种模态进行编码,如图3a。然而,当处理包含相似上下文的不同模态 (例如,不同染色的细胞图) 时,结合共享和非共享处理步骤可能会有益。特别是,认为对于这种情况,需要一个3步程序 (如图3b):

  1. 使用共享模型从所有模态中提取基本特征,以帮助泛化;
  2. 使用针对每种模态的单独模型执行模态特定分析;
  3. 通过跨所有模态共享的模型统一高级表示。

这种共享上下文处理方法使得可以在低级和高级特征统一的同时,允许中级特征处理的灵活性。

2.3 稀疏多模态细胞图神经网络

2.3.1 模态编码

对于每位患者,其有两种形式的数据:

  1. 来自不同模态的各种图像;
  2. 每种模态内的各种图像;

本文方法的第一阶段处理第二种类型的数据,其涉及特定于每种模态的处理分支:

  1. 每个分支包括一个单模态编码器,其后接一个实例注意力聚合器。给定 { x n , i m ∣ i = 0 , . . . , C ( n , m ) } \{x_{n,i}^m | i = 0, ..., C(n, m)\} {xn,imi=0,...,C(n,m)}作为输入,它生成单个表示向量 R n m ∈ R 1 × d R_{n}^m \in \mathbb{R}^{1 \times d} RnmR1×d
    • 该分支的编码器设计为一个GNN模型,由三个GraphSAGE层组成,每个层后面跟着一个SAGPool。SAGPooling层使模型能够通过选择图中最重要的节点来执行层次池化;
  2. 每个SAGPool后图节点的平均和最大池化嵌入被连接起来,对于不同的池化层,它们被加在一起,并通过一个2层多层感知机 (MLP) 传递;考虑到来自不同模态的输入图具有相似的上下文,它们可以从共享上下文处理中受益;
  3. 为了执行低级特征统一,使用矩阵分解将每个分支的第一层耦合起来。更具体地说,我们的GraphSAGE层遵循:
    h ^ k m = W s W m [ h k m , 1 K ∑ j ∈ N k h j m ] , (1) \tag{1} \hat{h}_{k}^m = W_s W_m [h_{k}^m, \frac{1}{K} \sum_{j \in \mathcal{N}_k} h_{j}^m], h^km=WsWm[hkm,K1jNkhjm],(1)其中 h k m h_{k}^m hkm h ^ k m \hat{h}_{k}^m h^km是节点 k k k在GraphSAGE层之前和之后的嵌入、 W s W_s Ws是跨模态分支共享的权重、 W m W_m Wm是模态分支 m m m的特定权重、 . , . ] ., .] .,.]是拼接操作、 N k \mathcal{N}_k Nk是连接到节点 k k k的节点集,以及 K K K N k \mathcal{N}_k Nk的大小;
    • 难点
      • 主要在于理解 h k m h_k^m hkm到底是什么,其实原文的符号有些混乱,这里的 h k m h_k^m hkm表示给定患者在模态 m m m下的第 k k k个节点, N k \mathcal{N}_k Nk是它的邻居节点的数量,其大小通过 K K K来调控;
      • GraphSAGE层是一个已有的网络,如果不想深入了解的话,只需要知道,它是用于汇聚给定节点,和其邻居节点信息,并得到一个新的表示;
      • SAGPool,简单理解就是以专用于图的池化方法,其利用自注意力来保留图中的重要部分,其输出依然是一个图,具体可以对应图2来看;
  4. 对于每个模态分支的第一个GraphSAGE层,设置 W s W_s Ws为一个可学习的矩阵,对于其他层,它将是一个全1矩阵;
  5. 为了组合患者特定模态内的多个图嵌入,使用实例注意力与实例归一化
    R m , n = InstanceNorm ( ∑ i = 0 C ( n , m ) σ ( W R n , i m ) R n , i m ) , (2) \tag{2} R_{m,n} = \text{InstanceNorm}\left(\sum_{i=0}^{C(n,m)} \sigma(W R_{n,i}^m) R_{n,i}^m\right), Rm,n=InstanceNorm i=0C(n,m)σ(WRn,im)Rn,im ,(2)其中 R n , i m 是 x n , i m R_{n,i}^m是x_{n,i}^m Rn,imxn,im的相应表示、 W W W是一个可学习的矩阵,以及 σ \sigma σ是sigmoid函数。实例注意力层执行共享上下文处理的高级聚合部分,通过跨所有模态共享 W W W
    难点
    • 这里主要是需要知道 R n , i m R_{n,i}^m Rn,im是什么,因为它其实和公式1没法对应。作者目前也是伪开源,所以需要结合图2来看;
    • 可以发现图2用到了多个GraphSAGE层和SAGPool,其学习的结果会交给一个MLP,再传递给公式2。因此, R n , i m R_{n,i}^m Rn,im其实就是每个节点通过图表示学习和MLP后的表示。

2.3.2 跨模态聚合

为了组合不同模态 ( R m , n ) (R_{m,n}) (Rm,n)的表示,采用了Transformer模型。每个注意力头的方程式如下:
H n = softmax ( Q n K n T d ) V n , (3) \tag{3} H_n = \text{softmax}\left(\frac{Q_n K_n^T}{\sqrt{d}}\right) V_n, Hn=softmax(d QnKnT)Vn,(3)其中 Q n , K n , V n Q_n, K_n, V_n Qn,Kn,Vn是将患者 n n n的表示矩阵 c o n c a t { R n 1 , . . . , R n M } concat\{R^1_n, ..., R^M_n\} concat{Rn1,...,RnM}分别输入到全连接层得到的 M × d M×d M×d的矩阵。

最后,所有头被连接起来,通过一个MLP传递,并跨模态平均,以生成患者的嵌入。通过获得对表示的一般理解,Transformer使模型能够进行跨所有染色模态的交叉注意力,以强调最具信息性的特征。

2.3.3 稀疏处理

尽管先前的研究强调了学习输入图的精确拓扑结构对于各种图相关应用的重要性,本文发现所提出的多模态方法对输入数据的稀疏性非常健壮。类似于计算机视觉领域最近的工作 (例如,MAE),利用这一发现进一步减少模型的计算复杂性。具体而言,在每个模态中,对输入图的特征和邻接矩阵执行掩码操作
X ^ = M X , A ^ = M A M T , M = P I , (4) \tag{4} \hat{X} = M X, \hat{A} = M A M^T, M = P\mathcal{I}, X^=MX,A^=MAMT,M=PI,(4)其中 X ∈ R c × d X \in \mathbb{R}^{c \times d} XRc×d是节点的特征矩阵、 A ∈ R c × c A \in \mathbb{R}^{c \times c} ARc×c是图的邻接矩阵、 c c c是图中节点的数量、 I ∈ R 1 × c \mathcal{I} \in \mathbb{R}^{1 \times c} IR1×c是全1矩阵,以及 P ∈ R c × 1 P \in \mathbb{R}^{c \times 1} PRc×1是掩码矩阵,其每个元素来自参数为 1 − s 1-s 1s的伯努利分布,其中 s s s是稀疏比率。随着 s s s的增加, X ^ \hat{X} X^ A ^ \hat{A} A^中的非零元素数量减少,导致后续计算操作的减少,称之为稀疏处理。实验将展示即使在训练期间数据的稀疏比率低至20%,模型也能保持相似的性能。

2.3.4 损失函数和BCP技术

生存预测是一个挑战性的任务,包括估计失败时间 (死亡) 作为一个连续变量。在最大似然估计的术语中,这意味着对于在特定时间失败的受试者,我们必须最大化该受试者相对于其他未失败受试者的失败概率。

考虑 t j t_j tj R ( t j ) R(t_j) R(tj)分别是受试者 j j j的失败时间和至少存活至时间 t j t_j tj的受试者集合。受试者 j j j失败概率计算如下:
P j ( T = t j ∣ R ( t j ) ) = P j ( T = t j ∣ T ≥ t j ) ∑ i : t i ≥ t j P i ( T = t j ∣ T ≥ t j ) . ( 5 ) (5) \tag{5} P_j(T = t_j | R(t_j)) = \frac{P_j(T = t_j | T \geq t_j)}{\sum_{i: t_i \geq t_j} P_i(T = t_j | T \geq t_j)}. \quad (5) Pj(T=tjR(tj))=i:titjPi(T=tjTtj)Pj(T=tjTtj).(5)(5)我们的训练目标是最大化每个 j j j的失败概率。特别地,一个小批量 B B B的总损失的期望值计算计算如下:
L batch = − E i ∼ U ( ⋅ ) [ log ⁡ P i ( T = t i ∣ R ( t i ) ) ] . (6) \tag{6} L_{\text{batch}} = -\mathbb{E}_{i \sim \mathcal{U}(\cdot)}[\log P_i(T = t_i | R(t_i))]. Lbatch=EiU()[logPi(T=tiR(ti))].(6)然而,上述损失有一个实际问题。问题源于公式5中的损失仅对有特定失败时间的受试者定义,并且对于在他们最新的随访中存活状态的受试者未定义 (通常称这类受试者为截尾数据)。因此,截尾受试者由于其未定义的损失 (显式梯度),在公式6的反向传播步骤中不作为单独的数据点提供梯度,这会干扰模型的正确训练。必须注意的是,这些受试者仍然通过非截尾受试者的损失 (公式5的分母) 的隐式梯度参与反向传播。

为了缓解这个问题,重新制定了损失函数:
L batch = − E i ∼ k U C ( . ) + ( 1 − k ) U N ( . ) [ log ⁡ P i ( T = t i ∣ R ( t i ) ) ] . (7) \tag{7} L_{\text{batch}} = -\mathbb{E}_{i \sim k \mathcal{U}_C(.) + (1-k) \mathcal{U}_N(.)}[\log P_i(T = t_i | R(t_i))]. Lbatch=EikUC(.)+(1k)UN(.)[logPi(T=tiR(ti))].(7)其中 U C ( ⋅ ) \mathcal{U}_C(\cdot) UC() U N ( ⋅ ) \mathcal{U}_N(\cdot) UN()分别是截尾和非截尾受试者的均匀分布, k k k来自参数为 α α α的伯努利分布。可以注意到,如果 α α α等于截尾案例的百分比,这两个方程是相等的。然而,我们将展示选择这个参数的适当值可以在截尾数据的隐式和显式梯度之间实现平衡,从而产生最高性能。我们称 α α α为批量截尾部分 (BCP)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值