MIL文献阅读(1) TransMIL
文献名:TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification
期刊:NeurIPS
发表时间:2021
任务类型:多实例学习(MIL)
前人存在的问题
-
之前的所有假设都假设instance都是独立同分布的(i.i.d);但事实上并不是。
-
之前的transformer的计算效率太低;
方法
问题构建(problem formulation)
MIL问题构建:
定理引入
定理一 说明可以利用文本的结构近似分数:
定理二 说明利用 相关attention的优势:相较于独立同分布假设可以减少无效信息;
方法介绍
算法流程图:
其中,函数 f f f 和 h h h 分别用来获取形态学特征和位置特征;
- 算法分为3步:
- 利用函数 f f f 和 函数 h h h 进行特征提取;
- 利用池化矩阵 P P P 进行池化,这里的池化矩阵是由自注意力机制得到的;
- 利用函数 g g g获得分类结果;
算法整体结构图如图所示:
从图中不难看出,函数的流程图和算法伪代码很好地对应。
使用预训练的ResNet50进行特征提取,对应函数 f f f和函数 h h h;利用自注意力机制进行特征映射,也就是函数 P P P;利用多层感知机实现分类,对应函数 g g g。
TPT模块
中间的特征映射部分由两个Transformer层和一个position编码层(PPEG)组成,因此被成为TPT模块。
TPT模块伪代码:
该模块作为文章中最重要的一个模块,下面将要详细介绍。
- Squaring of sequence
这步操作将整个特征矩阵 H S H_S HS 的特征向量的个数(去掉class token)可以被开方,方便后续关于位置编码的操作;
- attention机制
attention机制使用经典的多头注意力机制,详细说明在文章的附录中。
- PPEG获取位置信息
PPEG 也就是 Pyramid Position Encoding Generator,用来提供位置编码信息。
这样的设计起初被用在自然图像中,这里引入到医疗图像中,因此没有办法像自然图像那样保持原有的token之间的位置关系。但是通过文章的消融实验发现有不错的效果,因此该部分被保留。
算法整体流程见下图:
将patch tokens聚合成 N \sqrt{N} N大小的特征矩阵,经过不同大小的卷积网络卷积,求和得到最终融合了位置特征的特征矩阵。
需要回答的问题:
- 是如何应对每个instance不是独立同分布的?
使用多头注意力机制,同时考虑多个patch之间的关系。
- 如何保证使用attention机制的运算量较小?
利用 Nystrom Method 实现attention机制的近似计算。