1 要点
题目:OTSurv: A novel multiple instance learning framework for survival prediction
代码:https://github.com/Y-Research-SBU/OTSurv
研究动机:
生存预测是癌症诊断中至关重要的一步,但传统的生存预测方法面临以下挑战:
- 病理图像异质性问题:WSI图像包含多样化的病理特征,存在全局 (长尾分布的组织形态) 和局部 (具有不确定性的图像块) 异质性,需要模型能够有效捕捉这些差异。
- 现有MIL模型的不足:现有MIL方法未能充分考虑这两种异质性,导致在处理高分辨率的WSI图像时无法捕捉关键的生存相关信息。
研究目的:
旨在通过最优传输 (OT) 方法显式建模生存预测任务中的全局和局部病理异质性,从而提高生存预测的准确性与解释性。
关键技术:
-
最优传输 (Optimal Transport, OT) 框架:
- 通过全局长尾约束捕捉组织形态的长尾分布,防止模式崩溃和过度均匀化;
- 引入局部不确定性约束,优先考虑高置信度的图像块,并通过逐渐提高总传输质量来减少噪声。
-
异质性感知OT解决方案:将传统OT问题转化为不平衡OT问题,通过引入虚拟生存token来吸收未选择的质量,并以高效的矩阵缩放算法解决。
-
数据增强与生存token的学习:通过特征嵌入和特征聚合,利用学到的生存token进行预测。生存toekn的聚合与传输质量相结合,最终生成用于风险预测的生存嵌入。
数据集:TCGA 数据集:
- BLCA:359个样本
- BRCA:868个样本
- LUAD:412个样本
- STAD:318个样本
- CRC:296个样本
- KIRC:340个样本
2 方法
2.1 概述
如图1所示,OTSurv将基于WSI的生存预测重新框架为一个异质性感知的最优传输 (OT) 问题,其由四个模块组成:
- WSI 分解:将一个WSI W W W划分为 N N N个不重叠的图像块 x = { x i } i = 1 N ∈ R N × c × h × w x = \{ x_i \}_{i=1}^{N} \in \mathbb{R}^{N \times c \times h \times w} x={xi}i=1N∈RN×c×h×w,其中 c c c为颜色通道数, h × w h \times w h×w为每个图像块的分辨率;
- 特征嵌入:使用一个冻结的编码器 f e n c f_{enc} fenc提取每个图像块的特征 F = f e n c ( x ) ∈ R N × D F = f_{enc}(x) \in \mathbb{R}^{N \times D} F=fenc(x)∈RN×D,然后通过一个可学习的线性投影将这些特征投影到较低维度的潜在空间,得到实例嵌入 Z = f p r o j ( F ) ∈ R N × d Z = f_{proj}(F) \in \mathbb{R}^{N \times d} Z=fproj(F)∈RN×d;
- 特征聚合:异质性感知的OT模块计算得到最优传输计划 Q = f O T ( C ) ∈ R N × K Q = f_{OT}(C) \in \mathbb{R}^{N \times K} Q=fOT(C)∈RN×K,其中成本矩阵 C ∈ R N × K C \in \mathbb{R}^{N \times K} C∈RN×K通过计算实例嵌入 Z ∈ R N × d Z \in \mathbb{R}^{N \times d} Z∈RN×d和可学习的生存toekn S ∈ R K × d S \in \mathbb{R}^{K \times d} S∈RK×d之间的归一化欧几里得距离得到。通过聚合得到的切片级嵌入为 E = f a g g ( Q ⊤ Z ) ∈ R d E = f_{agg}(Q^\top Z) \in \mathbb{R}^d E=fagg(Q⊤Z)∈Rd,其中 Q ⊤ Z ∈ R K × d Q^\top Z \in \mathbb{R}^{K \times d} Q⊤Z∈RK×d和 f a g g ( ⋅ ) f_{agg}(\cdot) fagg(⋅)是一个线性层;
- 风险预测:聚合后的嵌入 E E E输入到一个线性预测器 f p r e d f_{pred} fpred中,输出最终的风险评分 r = f p r e d ( E ) ∈ R r = f_{pred}(E) \in \mathbb{R} r=fpred(E)∈R。整个模型通过最小化Cox比例风险损失进行优化。
注:这里的步骤和传统的使用token的方法没有太大的区别。而为了将OT问题转换为不平衡OT,以使得已有的求解策略生成,作者在成本矩阵添加了一列虚拟生存token,即全零向量,如图1(b),并在新的 Q Q Q (以新成本矩阵作为输入) 上添加KL散度约束和局部不确定性约束。

2.2 异质性感知OT问题
本节首先概述OT理论,然后讨论所提出的局部和全局OT边际分布约束,这些约束是OTSurv的核心。
2.2.1 OT概述
为了构建异质性感知的OT,我们首先从OT问题出发。具体地,OD的目标是将一个分布最优地传输到另一个分布,并最小化传输成本:
- 给定一个源分布
μ
∈
R
N
×
1
\mu \in \mathbb{R}^{N \times 1}
μ∈RN×1,一个目标分布
ν
∈
R
K
×
1
\nu \in \mathbb{R}^{K \times 1}
ν∈RK×1,以及一个成本矩阵
C
∈
R
N
×
K
C \in \mathbb{R}^{N \times K}
C∈RN×K,目标是确定一个传输矩阵
Q
∈
R
N
×
K
Q \in \mathbb{R}^{N \times K}
Q∈RN×K,使得:
min Q ∈ R + N × K ⟨ Q , C ⟩ F + F 1 ( Q 1 K , μ ) + F 2 ( Q ⊤ 1 N , ν ) (1) \tag{1} \min_{Q \in \mathbb{R}^{N \times K}_+} \langle Q, C \rangle_F + F_1(Q \mathbf{1}_K, \mu) + F_2(Q^\top \mathbf{1}_N, \nu) Q∈R+N×Kmin⟨Q,C⟩F+F1(Q1K,μ)+F2(Q⊤1N,ν)(1)其中 ⟨ ⋅ , ⋅ ⟩ F \langle \cdot, \cdot \rangle_F ⟨⋅,⋅⟩F表示Frobenius内积, Q i j Q_{ij} Qij表示从 μ i \mu_i μi到 ν j \nu_j νj传输的质量。 F 1 F_1 F1和 F 2 F_2 F2用于强制约束 Q Q Q的边际分布。 - 当这些约束被指定为等式时,即 Q 1 K = μ Q \mathbf{1}_K = \mu Q1K=μ, Q ⊤ 1 N = ν Q^\top \mathbf{1}_N = \nu Q⊤1N=ν,该公式退化为Kantorovich的经典OT问题。
- 当 F 1 F_1 F1和 F 2 F_2 F2施加不等式约束时,例如使用KL散度,问题转变为不平衡OT问题。
2.2.2 全局长尾约束
WSI表现出长尾的组织形态,其中主导的模式占主导地位,而稀有的、对生存有重要预测价值的特征则稀缺。如果没有任何约束,传输质量可能会集中在一个生存token上,而强制执行严格的等式约束则迫使质量在所有toekn间均匀分配,这种情况无法捕捉数据的真实长尾特性。
为了解决这个问题,在生存token的质量
Q
⊤
1
N
∈
R
K
×
1
Q^\top \mathbf{1}_N \in \mathbb{R}^{K \times 1}
Q⊤1N∈RK×1上施加KL散度约束,以匹配所需的长尾先验。这个全局约束
C
G
C_G
CG通过确保每个不频繁但至关重要的模式得到适当表示来保持组织多样性,以便准确的生存预测。具体地,在
F
2
(
⋅
)
F_2(\cdot)
F2(⋅)上施加一个KL散度惩罚,同时暂时假设
F
1
(
⋅
)
F_1(\cdot)
F1(⋅)是一个均匀分布,公式为:
min
Q
∈
Π
⟨
Q
,
C
⟩
F
+
λ
KL
(
Q
⊤
1
N
∥
1
K
1
K
)
subject to
Π
=
{
Q
∈
R
+
N
×
K
∣
Q
1
K
=
1
N
1
N
}
(2)
\tag{2} \begin{aligned} \min_{Q \in \Pi} &\langle Q, C \rangle_F + \lambda \text{KL} \left( Q^\top \mathbf{1}_N \parallel \frac{1}{K} \mathbf{1}_K \right)\\ \text{subject to }&\Pi=\{Q\in\mathbb{R}_+^{N\times K}|Q\mathbf{1}_K=\frac{1}{N}\mathbf{1}_N\} \end{aligned}
Q∈Πminsubject to ⟨Q,C⟩F+λKL(Q⊤1N∥K11K)Π={Q∈R+N×K∣Q1K=N11N}(2)其中
λ
\lambda
λ是控制KL散度正则化的加权因子。
2.2.3 局部不确定性感知约束
在公式(2)中, F 1 ( ⋅ ) F_1(\cdot) F1(⋅)约束 Q 1 K = 1 N 1 N Q \mathbf{1}_K = \frac{1}{N} \mathbf{1}_N Q1K=N11N将所有实例视为平等对待,这可能会导致由于初始表示不佳而产生噪声对齐。受到课程学习 (curriculum learning) 启发,一种更有效的方法是从容易的样本开始,然后逐步引入更难的样本:
在成本矩阵中重新公式化选择过程为总质量约束,以消除对敏感超参数调优的需求:
min
Q
∈
Π
⟨
Q
,
C
⟩
F
+
λ
KL
(
Q
⊤
1
N
∥
ρ
K
1
K
)
(3)
\tag{3} \min_{Q \in \Pi} \langle Q, C \rangle_F + \lambda \text{KL} \left( Q^\top 1_N \parallel \frac{\rho}{K} 1_K \right)
Q∈Πmin⟨Q,C⟩F+λKL(Q⊤1N∥Kρ1K)(3)其中
ρ
∈
(
0
,
1
]
\rho \in (0, 1]
ρ∈(0,1]是所选质量的比例,在训练过程中逐渐增加。具体地,Sigmoid ramp-up 函数被用于逐步增加
ρ
\rho
ρ:
ρ
=
ρ
0
+
(
1
−
ρ
0
)
⋅
e
−
5
(
1
−
t
/
(
T
⋅
I
)
)
2
(4)
\tag{4} \rho = \rho_0 + (1 - \rho_0) \cdot e^{-5(1 - t / (T \cdot I))^2}
ρ=ρ0+(1−ρ0)⋅e−5(1−t/(T⋅I))2(4)其中
t
t
t是当前迭代次数、
T
T
T是ramp-up 训练周期,以及
I
I
I是每个周期的迭代次数。
2.2.4 异质性感知OT
公式(3)被称为异质性感知OT,因为它结合了全局和局部的边际分布约束。通过在训练过程中逐步增加 ρ \rho ρ,这种方法在一个统一的 OT框架内无缝集成了实例选择、加权和聚合过程。
2.3 异质性感知的OT求解器
现有的缩放算法旨在解决不平衡的OT问题,而异质性感知OT由于其包含两个定制的边际约束,使得其与标准的不平衡OT不同。尽管如此,通过引入一个虚拟点,可以将其框架转换为一个不平衡OT问题,从而使这些高效的算法能够解决该问题。
如图1(b)所示,引入一个虚拟生存token来吸收
1
−
ρ
1 - \rho
1−ρ的未选择质量,确保与不平衡OT框架的兼容性。具体来说,通过向成本矩阵
C
C
C添加一列零,即虚拟令牌,以对其扩展,得到
C
^
∈
R
+
N
×
(
K
+
1
)
\hat{C} \in \mathbb{R}^{N \times (K+1)}_+
C^∈R+N×(K+1)。
因此,公式(3)可以重写为:
min
Q
^
∈
Φ
⟨
Q
^
,
C
^
⟩
F
+
KL
(
Q
^
⊤
1
N
,
β
,
λ
^
)
subject to
Φ
=
{
Q
^
∈
R
+
N
×
(
K
+
1
)
∣
Q
^
1
K
+
1
=
1
N
1
N
}
(5)
\tag{5} \begin{aligned} \min_{Q̂ \in \Phi} &\langle Q̂, \hat{C} \rangle_F + \text{KL} \left( Q̂^\top \mathbf{1}_N,\beta, \hat{\lambda}\right)\\ \text{subject to }& \Phi=\left\{ Q̂ \in \mathbb{R}^{N \times (K+1)}_+ \mid Q̂ \mathbf{1}_{K+1} = \frac{1}{N} \mathbf{1}_N \right\} \end{aligned}
Q^∈Φminsubject to ⟨Q^,C^⟩F+KL(Q^⊤1N,β,λ^)Φ={Q^∈R+N×(K+1)∣Q^1K+1=N11N}(5)其中
C
^
=
[
C
,
0
N
]
,
β
=
[
ρ
K
1
K
1
−
ρ
]
,
λ
^
=
[
λ
1
K
+
∞
]
,
KL
^
(
Q
^
⊤
1
N
,
β
,
λ
^
)
=
∑
i
=
1
K
+
1
λ
^
i
[
Q
^
⊤
1
N
]
i
log
[
Q
^
⊤
1
N
]
i
β
i
.
(6)
\tag{6} \begin{aligned} &\hat{C} = [C, \mathbf{0}_N], \quad \beta =\left[ \begin{aligned} \frac{\rho}{K} \mathbf{1}_K\\ 1 - \rho \end{aligned} \right], \quad \hat{\lambda} =\left[ \begin{aligned} \lambda \mathbf{1}_K\\ +\infty \end{aligned} \right],\\ &\hat{\text{KL}}\left( \hat{Q}^\top \mathbf{1}_N, \beta, \hat{\lambda} \right) = \sum_{i=1}^{K+1} \hat{\lambda}_i \left[ \hat{Q}^\top \mathbf{1}_N \right]_i \log \frac{ \left[ \hat{Q}^\top \mathbf{1}_N \right]_i }{ \beta_i }. \end{aligned}
C^=[C,0N],β=
Kρ1K1−ρ
,λ^=[λ1K+∞],KL^(Q^⊤1N,β,λ^)=i=1∑K+1λ^i[Q^⊤1N]ilogβi[Q^⊤1N]i.(6)其中加权的KL散度用于约束虚拟生存令牌准确地吸收
1
−
ρ
1 - \rho
1−ρ的未选择质量 (即未被传输的部分)。通过这种方法,能够将原始问题转换为一个可以通过现有算法解决的不平衡OT问题。
1970

被折叠的 条评论
为什么被折叠?



