前言
随着研究的深入,SAM
被拓展到医学图像分割领域。但研究表明,直接将 SAM
用于医学图像分割,效果非常差。但医疗数据难以获取以及高昂的注释成本迫切的需要一个基础模型来打卡局面,不仅仅在图像分割层面,更是在数据注释方面。
本文介绍了第一个将 Adaption
微调方法用于 SAM
的模型,该模型在 19 个数据集上都取得了惊人的效果。这为后续 SAM
的研究和 fine-tune 工作,提供了一个有效的参考和指导。
原论文链接:Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation
1. Abstract & Introduction
1.1. Abstract
最近的许多评测任务表明,SAM
在医学图像分割方面的表现不尽如人意(如我先前研读的另一篇论文:MSA【1】:Segment Anything Model for Medical Image Analysis: an Experimental Study)。一个自然的问题是,如何将 SAM
出色的 zero-shot 能力迁移到医学图像领域中。
因此,需要寻找缺失的部分来扩展 SAM
的分割能力。本文提出了 Medical SAM Adapter (MSA)
,将医学特定领域知识通过简单而有效的适应技术集成到分割模型中,而不是仅对 SAM
模型微调。这个可能的解决方案,即根据 Adapter
的参数效率微调范式对预训练的 SAM
模型进行微调,这一简单的实现在医学图像分割上显示出了令人惊讶的良好性能。
1.2. Introduction
SAM
作为一种强大的通用视觉分割模型能够根据用户提示生成各种精细的分割掩膜,但最近的许多研究表明它在医学图像分割上表现不佳。
SAM
在医学图像上失败的主要原因是缺乏训练数据。虽然 SAM
在训练过程中建立了一个复杂高效的数据引擎,但他们收集到的医学应用案例却很少。
为了将 SAM
拓展到医学图像分割,本文选择使用自适应(Adaption
)的参数高效微调(PEFT
)范式对预训练的 SAM
进行微调。
Adaption
是自然图像处理(NLP)中一种流行且广泛使用的技术,用于针对特定用途对基本预训练模型进行微调。其主要思想是在原始基本模型中插入几个参数效率高的 Adapter
模块,然后只调整 Adapter
参数,冻结所有预训练参数。
1.2.1. Why do we need SAM for medical image segmentation?
交互式分割是所有分割任务的范式,而 SAM
提供了一个很好的框架,使其成为实现基于提示的医学图像分割的一个完美基准。
1.2.2. Why fine-tuning?
SAM
的预训练模型是通过精心设计的数据引擎在全球最大的分割数据集上训练出来的。同时,有相当多的研究表明,在自然图像上进行预训练也有利于医学图像分割。
1.2.3. Why PEFT and Adaption?
PEFT
是针对特定用途微调大型基本模型的有效策略- 与完全微调相比,
PEFT
保持了大部分参数的冻结状态,学习的参数大大减少,通常不到总参数的 5% - 学习效率更高,更新速度更快
PEFT
方法通常比完全微调效果更好,因为它们能避免灾难性遗忘,并能更好地泛化到域外场景,尤其是在低数据状态下
- 与完全微调相比,
Adaption
是针对下游任务微调大型基本视觉模型的有效工具
2. Method
2.1. Preliminary: SAM architecture
SAM
由三个主要部分组成:图像编码器、提示编码器和掩码解码器
- 图像编码器基于经
MAE
预先训练的标准视觉变换器(ViT
)- 图像编码器使用的是 ViTH/16 变体,它采用 14 × 14 14\times14 14×14 窗口注意力和四个等间距全局注意力块
- 图像编码器的输出是输入图像的 16 16 16 倍降采样嵌入
- 有关
ViT
的详细介绍可以参考我的另一篇 blog:CV-Model【6】:Vision Transformer
- 提示编码器可以是稀疏的(
point
、box
、text
),也可以是密集的(mask
) - 掩码解码器是一个
Transformer decoder block
- 包含一个动态掩码预测头
SAM
在每个 block 中使用双向交叉注意,一个用于提示到图像的嵌入,另一个用于图像到提示的嵌入,以学习提示和图像嵌入之间的交互- 在运行两个 block 后,
SAM
会对图像嵌入进行上采样,然后由MLP
将输出标记映射到动态线性分类器,从而预测给定图像的目标掩码
有关 Segment Anything Model
的详细介绍可以参考我的另一篇 blog:SAM【1】:Segment Anything
2.2. MSA architecture
为了对 SAM
架构进行微调以适应医学图像分割,本文没有完全调整所有参数,而是冻结了预先训练好的 SAM
参数,并在架构的特定位置插入了 Adapter
模块。
Adapter
是一个 bottleneck 结构,它依次包括:下投影、ReLU
激活和上投影。下投影使用简单的 MLP
层将给定的嵌入压缩到较小的维度;上投影使用另一个 MLP
层将压缩的嵌入扩展回其原始维度。
2.2.1. 2D Medical Image Adaption
在 SAM
编码器中,本文为每个 ViT
块部署了两个 Adapter
修改标准 ViT block (a)
,得到 2D Medical Image Adaption (b)
- 将第一个
Adapter
放在多头注意力之后、残差连接之前 - 将第二个
Adapter
放在多头注意力之后MLP
层的残差路径上 - 紧接着第二个
Adapter
之后,按照一定的比例系数 s s s 对嵌入进行了缩放- 引入缩放因子 s 是为了平衡与任务无关的特征和与任务有关的特征
- 默认值为 0.1(参考论文中效果最好的情况)
2.2.2. Decoder Adaption
在 SAM
解码器中,本文为每个 ViT
块部署了三个 Adapter
修改标准 ViT block (a)
,得到 Decoder Adaption (b)
- 第一个
Adapter
部署在prompt-to-image
嵌入的多头交叉注意之后,并添加了提示嵌入的残差- 本文使用了另一种向下投影来压缩提示嵌入,并在
ReLU
激活之前将其添加到Adapter
的嵌入上 - 有助于
Adapter
根据提示信息调整参数,使其更加灵活和通用于不同的模式和下游任务
- 本文使用了另一种向下投影来压缩提示嵌入,并在
- 第二个
Adapter
的部署方式与编码器完全相同,用于调整MLP
增强嵌入 - 第三个
Adapter
部署在图像嵌入的残差连接之后,以提示交叉注意 - 另一个残差连接和层归一化在自适应后连接,以输出最终结果
2.2.3. 3D Medical Image Adaption
尽管 SAM
可以应用于病灶的每个切片以获得最终的分割,但是它没有考虑深度维中的相关性
本文提出了一种新的适配方法,其灵感来源于 image-to-video adaptation
,具体架构如 (c)
- 在每个 block 中,本文将注意力操作分成两个分支:空间分支和深度分支
- 对于深度为
D
D
D 的给定 3D 样本
- 将维度为
D
×
N
×
L
D \times N \times L
D×N×L 的数据发送给空间分支中的
Multi-head Attention
- 其中
N
N
N 是
embedding
的数量, L L L 是embedding
的长度, D D D 是运算的个数 - 交互作用应用在 N × L N \times L N×L 上学习和抽象空间相关性作为嵌入
- 在多头注意力机制中,每个头的注意力计算都涉及三个线性变换:查询(query)、键(key)和值(value),这些变换会映射输入序列到不同的表示子空间。在这些表示子空间中,不同头之间会发生交互作用,以便在模型中学习和捕捉不同的特征
- 具体来说,交互作用发生在两个阶段:
- 头部内交互(
Intra-head Interaction
):在每个头中,注意力权重是通过查询与键的点积计算得到的。在这个计算过程中,查询和键会相互交互,以确定查询与键之间的相关性。这样,每个头可以根据输入序列的不同部分选择性地关注或忽略信息 - 头部间交互(
Inter-head Interaction
):在多头注意力中,每个头计算出的注意力权重和对应的值会通过加权求和的方式进行合并。这种合并可以看作是不同头之间的交互作用,它使得每个头能够综合不同的特征表示。通过多头的合并,模型可以在综合多个头的信息后得到更全面和丰富的上下文表示
- 头部内交互(
- 其中
N
N
N 是
- 在深度分支中,首先对输入矩阵进行转置,得到维度为
N
×
D
×
L
N \times D \times L
N×D×L 的数据,然后将其发送给相同的
Multi-head Attention
- 虽然使用相同的注意力机制,但是相互作用是在 D × L D \times L D×L 上应用的
- 通过这种方式,深度相关性被学习和抽象
- 将维度为
D
×
N
×
L
D \times N \times L
D×N×L 的数据发送给空间分支中的
- 最后,将来自深度分支的结果转换回它们的原始形状,并将它们添加到空间分支的结果中
2.3. Training Strategy
2.3.1. Encoder Pretrainig
使用医学图像在编码器上进行预训练
本文结合多种自监督学习方法进行预训练
Contrastive Embedding-Mixup (e-Mix)
e-Mix
是一种对比目标,用于文本数据的无监督表示学习- 对比目标是一种自监督学习的目标函数,常用于训练无标签数据的表示学习模型
- 在对比目标中,每个样本都会被转换成多个视图,这些视图之间通常被称为
anchor
(锚点)和positive
(正样本) - 它们都来自于同一个原始样本,但可能经过不同的变换(数据增强)以产生不同视图
- 然后,对于每个
anchor
,我们需要将其与其对应的正样本进行对比,使得它们的表示在表示空间中更加接近,而与其他样本的表示更加分散
- 在对比目标中,每个样本都会被转换成多个视图,这些视图之间通常被称为
e-Mix
可以对一批原始输入嵌入进行加权混合,并使用不同的系数对它们进行加权。然后训练编码器产生一个混合嵌入向量,该向量与原始输入的混合系数成比例地接近原始输入的嵌入
- 对比目标是一种自监督学习的目标函数,常用于训练无标签数据的表示学习模型
Shuffled Embedding Prediction (ShED)
ShED
混合了一部分嵌入,并用分类器训练编码器来预测哪些嵌入受到了干扰ShED
旨在从未标明的文本数据中学习有用的上下文表征ShED
的目标基于预测句子中洗牌词或标记嵌入的原始顺序这一理念
MAE
MAE
掩盖了输入嵌入的一个给定部分,并训练模型来重建它们
2.3.2. Training with Prompt
本文在新的医学图像数据集上适应 SAM
,这个过程基本上与 SAM
中的过程相同,本文只考虑了 2 种 prompt:point
& text
(text
部分的代码原作者目前似乎还未公开)
point
- 本文使用随机和迭代点击采样策略的组合训练此 prompt
- 具体流程如下:
- 首先使用随机采样进行初始化,然后使用迭代采样过程添加一些
point
- 迭代采样策略类似于与真实用户的交互,因为在实践中,每个新的
point
都被放置在网络使用以前的点击集产生的预测的错误区域 - 生成随机抽样和模拟迭代抽样
- 首先使用随机采样进行初始化,然后使用迭代采样过程添加一些
text
- 在
SAM
中,作者使用CLIP
产生的目标对象作物的图像嵌入作为嵌入在CLIP
中接近其相应文本描述或定义的图像。然而,由于CLIP
很少在医学图像数据集上进行训练,它很难将图像上的器官/病变与相应的文本定义联系起来 - 医学图像分割需要精确识别和标记图像中的不同结构,而不准确的标签会导致错误的诊断和治疗
- 为了克服这一局限性,本文提出了一种不同的训练策略,即生成包含目标定义的自由文本,作为 Chat-GPT 中的关键字,然后使用
CLIP
作为训练提示提取文本的嵌入- 通过使用这种方法,可以确保文本定义与医学图像相关,并且模型可以准确地将图像中的不同结构与其相应的文本定义联系起来
- 在
总结
通过使用 PEFT & Adaption
(一种具有成本效益的微调技术),本文在原有的 SAM
模型上取得了显著的改进,并在 5 种不同图像模式下的 19 个医疗图像分割任务上取得了最先进的性能。
这些结果表明了本文的方法适应医学图像的有效性和潜力转移强大的一般分割模型的医疗应用。