Attention-based Deep Multiple Instance Learning
基于注意力机制的深度多示例学习
时间:2023/06/05
摘要
-
多示例学习(MIL)定义为学习包标签的伯努利分布,其中包标签的概率由神经网络完全参数化。
-
提出了一种基于NN的置换不变聚合算子,对应注意力机制。
-
该算子提供了每个示例对包标签的贡献的深入了解
-
实验结果,在banchmark数据集上取得与最佳MIL方法相当的性能,在不牺牲可解释性的情况下,在MNIST数据集和两个真实组织病理学数据集上优于其他方法。
引言
存在问题:
MIL的任务中,发现关键示例(Key instance)是许多MIL应用领域的一个重要挑战,例如医学图像中发现关键示例可以为临床实践提供较高的价值。对于分类问题,基于嵌入的MIL方法可以提供相对较高的准确率,但其可解释性就比较弱;基于示例的MIL方法的可解释性较高,但是其准确率则较低,且通常情况下MIL方法在示例级存在分歧。
目标: 将可解释性纳入MIL方法并增加其灵活性。
提出方案:
使用包标签的伯努利分布来建立MIL模型,并通过优化对数似然函数来训练它。
Fundamental Theorem of Symmetric Functions(对称函数的基本定理)——置换不变聚合函数
一个通用程序:
第一步:实例到低维嵌入的变换
第二步:置换不变(对称)聚合函数
第三步:包概率的最终转换
建议采用神经网络(卷积层+全连接层)对所有变换进行参数化。目的是为了使用神经网络的方法,允许通过优化无约束目标函数以端到端的方式训练模型。增加了灵活性。
建议用可训练加权平均代替广泛使用的置换不变算子(Max池化和Mean池化)。加权平均的权值由双层神经网络给出,即采用注意力机制为示例分配权重。值得注意的一点,注意力机制可以帮助我们找到关键示例,这些实例可以进一步用于突出可能的ROI。(投资回报率 (ROI)???)
在图像数据集中,我们提供了经验证据,证明我们的模型可以指示关键实例。
Methodology
MIL
1.基于标准MIL假设公式化MIL
Y = { 0 , i f f ∑ k y k = 0 , 1 , o t h e r w i s e . (1) Y= \begin{cases} 0,&iff\sum_ky_k=0,\\ 1,&otherwise. \end{cases} \tag{1} Y={0,1,iff∑kyk=0,otherwise.(1)
意味着MIL模型是permutation-invariant(置换不变的)(对称的)
2.采用最大算子表示:
Y = max k y k (2) Y=\max_k{y_k}\tag{2} Y=kmaxyk(2)
学习一个试图基于最大超实例标签来优化目标的模型是有问题的
原因:
-
所有基于梯度的学习方法都会遇到梯度消失的问题。
-
只有当使用实例级分类器时,此公式才适用
本文建议通过优化对数似然函数来训练MIL模型,其中包标签服从参数为 θ ( X ) ∈ [ 0.1 ] \theta(X)\in [0.1] θ(X)∈[0.1]的伯努利分布,即包含示例集X的包,标签Y=1的概率。
MIL方法
在MIL设置中,袋概率 θ ( X ) \theta(X) θ(X)必须是置换不变的,因为我们既不假设袋内实例的排序也不假设其依赖性。---->可以从对称函数的基本定理考虑MIL问题。
定理1:对于一组示例 X X X,其得分函数 S ( X ) ∈ R S(X)\in \mathbb{R} S(X)∈R是对称函数(即对 X X X中的元素置换不变),当且仅当其可以分解为以下形式:
S ( X ) = g ( ∑ x ∈ X f ( x ) ) (3) S(X)=g(\sum_{x\in X}f(x))\tag{3} S(X)=g(x∈X∑f(x))(3)
其中g和f是适合的变换。
定理2:采用max代替sum,提供近似分解。
∣ S ( X ) − g ( max x ∈ X f ( x ) ) ∣ < ϵ (4) |S(X)-g(\max_{x\in X}f(x))|<\epsilon\tag{4} ∣S(X)−g(x∈Xmaxf(x))∣<ϵ(4)
定理1和定理2的区别在于,前者是普遍分解,而后者提供了任意近似。
通用的三步方法:
(i)使用函数f的实例变换
(ii)使用对称(置换不变)函数σ的变换实例的组合
(iii)使用函数g的f变换的组合实例的变换。
最后,分数函数的表现力依赖于f和g的函数类的选择。
MIL问题公式:
得分函数 S ( X ) S(X) S(X)==》概率
置换不变函数 σ \sigma σ==》MIL池化
函数 f 、 g 和 σ f、g和\sigma f、g和σ的选择决定了对标签概率建模的具体方法
两种MIL方法:
-
The instance-level approach:
-
转换方法 f f f:示例级分类器,输出示例的得分。
-
MIL池化:对示例得分进行聚合得到 θ ( X ) \theta(X) θ(X)。
-
函数 g g g:恒等函数。
-
-
The embedding-level approach:
-
转换方法 f f f:将示例嵌入至低维度。
-
MIL池化:将所有的低维示例嵌入成包表示。
-
通过包级分类器对包表示进行分类输出 θ ( X ) \theta(X) θ(X)
-
MIL with Neural Networks
使用神经网络参数化所有变换使得整个方法可以任意灵活,并且可以通过反向传播进行端到端训练。唯一的限制是MIL池必须是可微分的——便于梯度反向传播。
MIL pooling
MIL问题的公式化要求MIL池化 σ \sigma σ是置换不变的。常见的两种方式:最大池化和平均池化。
除此之外还有:
-
凸最大算子:log-sum-exp
-
Integrated Segmentation and Recognition
-
noisy-or和noisy-and
Attention-based MIL pooling
先前的MIL池化方法存在问题:预定义和不可训练的。
eg:max池化适用于示例级,但不适应包级。mean池化对聚合示例得分效果很差,但适用于包表示。
灵活和自适应的MIL池化方法可以通过调整任务和数据来获得更好的结果。这种MIL池化不同于之前的池化方法,即具备可解释性。
Attention mechanism
示例使用通过神经网络确定的权值,且保证所有的权值之和为1。加权平均满足定理1的要求,其中权重与嵌入一起是函数 f f f的一部分。
对于低维嵌入后的示例 H = { h 1 , . . . , h K } H=\{h_1,...,h_K\} H={h1,...,hK},本文的MIL Pooling处理:
z = ∑ k = 1 K a k h k (5) z=\sum_{k=1}^Ka_kh_k\tag{5} z=k=1∑Kakhk(5)
其中:
a k = e x p { w T t a n h ( V h k T ) } ∑ j = 1 K e x p { w T t a n h ( V h j T ) } (6) a_k={exp\{w^Ttanh(Vh_k^T)\}\over\sum_{j=1}^Kexp\{w^Ttanh(Vh_j^T)\}}\tag{6} ak=∑j=1Kexp{wTtanh(VhjT)}exp{wTtanh(VhkT)}(6)
其中 w ∈ R L × 1 w\in \mathbb{R}^{L\times 1} w∈RL×1, V ∈ R L × M V\in \mathbb{R}^{L\times M} V∈RL×M
我们利用双曲正切 t a n h ( ⋅ ) tanh(·) tanh(⋅)单元非线性来包括适当梯度流的负值和正值。所提出的构造允许发现实例之间的相似性。
有趣的是,所提出的MIL池对应于注意力机制的一个版本(Lin等人,2017;Raffel&Ellis,2015)。
主要的区别是,通常在注意力机制中,所有实例都是顺序相关的,而这里我们假设所有实例是独立的。
因此,一个自然产生的问题是,在没有实例之间的顺序依赖关系的情况下,注意力机制是否可以工作,以及它是否不会学习均值算子。
我们将在实验中解决这个问题。
Gated attention mechanism
存在问题:
tanh(·)非线性对于学习复杂关系可能是低效的
tanh(x)对于x∈[-1,1]是近似线性的,这可能会限制实例之间学习关系的最终表现力
推荐:使用门控机制加 t a n h ( ⋅ ) tanh(·) tanh(⋅),即:
a k = e x p { w T ( t a n h ( V h k T ) ⨀ s i g m ( U h k T ) ) } ∑ j = 1 K e x p { w T ( t a n h ( V h j T ) ⨀ s i g m ( U h j T ) ) } (7) a_k={exp\{w^T(tanh(Vh_k^T)\bigodot sigm(Uh_k^T)) \}\over\sum_{j=1}^Kexp\{w^T(tanh(Vh_j^T)\bigodot sigm(Uh_j^T))\}}\tag{7} ak=∑j=1Kexp{wT(tanh(VhjT)⨀sigm(UhjT))}exp{wT(tanh(VhkT)⨀sigm(UhkT))}(7)
其中 U ∈ R L × M U\in \mathbb{R}^{L\times M} U∈RL×M是参数。 ⨀ \bigodot ⨀是按元素相乘, s i g m ( ⋅ ) sigm(\cdot) sigm(⋅)是sigmod。门控机制引入了一种可学习的非线性,可能会消除前面第二条提到的tanh中的线性。
Interpretability
理想情况下,注意力机制会将较高的权值分配给关键示例( y k = 1 y_k=1 yk=1的示例)。
注意力机制允许根据实例级标签来容易地解释所提供的决策。
注意力网络不像基于实例的分类器那样提供分数,但它可以被视为分数的代理。
基于注意力的MIL池连接了实例级方法和嵌入级方法
Experiments
一个用神经网络参数化的MIL模型和一个基于(门控)注意力的池化层
实验目标:
(i)我们的方法是否达到了最佳性能或与性能最佳的方法相当
(ii)我们的算法是否可以通过使用指示关键实例或ROI的注意力权重来提供可解释的结果。
实验数据集:
-
benchmark(MUSK1,MUSK2,FOX,TIGER,ELEPHANT)
-
手写识别数据集:MNIST-BAGS
-
医学诊断数据集:Breast Cancer(乳腺癌),Colon Cancer(结肠癌)
如果使用基于注意力的MIL池化层,则使用验证集确定V中的参数数量。我们测试了以下维度(L):64、128和256。不同的尺寸只导致了模型性能的微小变化。对于使用门控注意力机制的层,V和U具有相同数量的参数。
由于注意力在反向传播过程中充当梯度更新过滤器,因此具有更高权重的实例将对学习实例的编码器网络做出更大贡献。
所提出的方法不仅是最准确的,而且获得了最高的召回率。高召回率在医学领域尤其重要,因为假阴性可能导致包括患者死亡在内的严重后果。
在图5中,我们显示了组织病理学图像,该图像被划分为包含(大部分)单细胞的斑块。我们通过将补丁乘以其相应的注意力权重来创建热图。尽管在训练过程中只使用了图像级别的注释,但图5(d)中的热图和图5(c)中的基本事实之间有很大的匹配。