[医学分割大模型系列] -3- SAM-Med3D 分割大模型解析
论文地址:SAM-Med3D
开源地址:https://github.com/uni-medical/SAM-Med3D
发表日期:2023年10月
参考资料:
- 王皓宇(上海交通大学)SAM-Med3D基于SAM构建3D医学影像通用分割模型
- SAM-Med3D:三维医学图像上的通用分割模型,医疗版三维 SAM 开源了!
- SAM-Med3D (SJTU 2024)
1. 特点
- 通用分割能力:在各种3D目标上精准分割,效果明显优于SAM,SAM-Med2D(相对于切片进行2D分割)
- 更高的效率:比现有通用分割模型更快,提示需求更少(相对于切片进行2D分割)
- 迁移能力:作为预训练模型,在多个任务上效果良好
- 模型输入:要分割的图像和一个/几个提示点(提示点越多,效果越好)
- 模型输出:分割结果
- 数据集:SAM-Med3D-130K数据集,拥有 131K 3D mask和 247 个类别
- 网络结构:类SAM,将结构换成3D版本
- 分割对象:3D医学图像
2. 背景
- 3D医学图像:体素形式的3D图像和标注,以不同分布的灰度图像为主
- 任务特定模型的局限:
- 沉重的训练负担:使用U-Net,UNETR等分割网络在医学数据集上训练,使用A100也需要2-7天
- 泛化性弱
使用特定数据集训练出来的模型(左列)在其他数据集上的表现(行)不佳
- 沉重的训练负担:使用U-Net,UNETR等分割网络在医学数据集上训练,使用A100也需要2-7天
- SAM在3D医学分割的局限:
- 由于医学图像知识的严重不足,将 SAM 直接应用于医学领域的有效性有限。解决这个问题的一种直接的方法是:将医学知识融入到 SAM 中。比如,MedSAM 是一种典型示例,它通过使用110万个掩码(mask)对SAM 的解码器(Mask Decoder)进行微调,从而使 SAM 能够通过边界框(Bounding Box)作为提示来更好地分割医学影像;SAM-Med2D 则引入了适配器(Adapter)和约2000万个掩码(mask)对 SAM 进行了充分微调,从而在医学图像分割中表现出了卓越的性能。
- 然而,这些方法必须采用逐切片(slice)的方法来处理三维医学图像,也即,将三维数据从某个维度分解为二维切片,然后独立处理每个切片,最后将二维分割结果汇总为三维分割结果。这种方法忽略了切片之间的三维空间信息,因此在三维医学影像上表现不佳,这一问题可以从上图中的结果看出。SAM和SAM-Med2D都是一张张切片进行分割,每张切片都需要一个提示,所以总共需要N个提示。对于一些切片,他们的表现不佳,从而导致空间信息的不连贯。
- 除了将 SAM 直接应用于三维数据,一些研究人员希望通过引入二维到三维的适配器(Adapter)来捕捉三维空间信息。这些方法通常在保持编码器(Image Encoder)不变的同时引入了三维适配器(Adapter),以使模型能够从三维图像中学习到三维空间信息。然而,这些方法存在两个主要限制:(1)数据规模有限:这些方法的模型通常只在有限的数据规模下进行训练(通常在1K到25K个mask范围内),并且只针对有限的目标类型。这限制了模型的泛化性能和适用范围。(2)冻结的二维编码器:现有的三维 SAM-based 模型一直坚守着冻结原始二维 SAM 编码器(Image Encoder)的设计范式,这限制了模型全面建模三维空间信息的能力,大大限制了 SAM 在三维医学图像处理领域的发展潜力。
3. 训练数据集
3.1 数据集收集
作者进行了三维医学图像数据集的广泛收集和标准化工作,整合了116个公开和私有的三维医学图像数据集,经过4轮数据筛选和清晰,创建了迄今为止规模最大的三维医学图像分割数据集。该数据集包含了 2.1 万个三维医学图像(病人数量)和 13.1 万个三维掩码(mask)。从下表可以清晰地看出,这一数据集的规模远远超过了现有最大的三维医学图像分割数据集,如 TotalSegmentator 和 BraTS21,其规模扩大了 10 倍以上。
该数据集涵盖 27 种模态(CT 和 26 种MRI 序列)和 7 种解剖结构。如下图所⽰,共涵盖了 247 个不同的类别,包括器官和病变。
3.2 数据清洗
四步数据清洗:
- 基于元信息的数据清理 我们首先总结了所收集数据的元信息,包括每张医学影像的深度、宽度和高度。我们删除了所有物理尺寸小于 1 立方厘米或任何单个尺寸小于 1.5 厘米的病例,以确保目标mask的可见性。
- 基于连接域的掩码清理 在计算连通域的过程中,我们首先将原始的多类mask分割成多个类别的单击格式。然后,我们计算每个单击掩码的前 5 个最大连通域的大小和背景。根据这些掩码的汇总信息,我们会删除背景占整个体积 99% 以上的mask。
- 基于连接域的标签质量改进 对于过滤后的mask,我们设计了一个基于连接域的pipeline来提高标签质量。根据每个mask的前 5 个最大连通域的汇总信息,我们只需删除小于这 5 个连通域的任何其他域,以减少噪音。
- 基于对称性的标签质量改进 最后,我们将一些对称目标的mask拆分为不同类别的成对mask。例如,我们将 "肾 "的mask分为 "左肾 "和 “右肾”。这一步的目的是加强不同类别mask的语义一致性,防止模型分不清是分割整个结构还是只分割单个的左右部分。为了解决这个问题,SAM 为每个提示生成多个预测,并采用额外的头部生成分数,以方便选择最合适的预测。鉴于医学图像的mask通常不那么模糊,我们选择直接处理数据来消除这种模糊性,从而增强mask类别之间的语义一致性,降低网络训练的复杂性。
3.3 模型微调数据集
目前SAM-Med3D-turbo是现已发布经过微调的 SAM-Med3D 的最新版本checkpoint。在SAM-Med3D的基础上又在 44 个数据集 ( 以下list )上对其进行了微调以提高性能。
AMOS2022
ATM2022
AbdomenCT1K
BTCV_Cervix
BraTS2020
BraTS2021
BrainTumour
Brain_PTM
CAUSE07
CHAOS_Task_4
COSMOS2022
COVID19CTscans
CTPelvic1k
CT_ORG
FLARE21
FLARE22
Heart_Seg_MRI
ISLES_SISS
ISLES_SPES
KiPA22
KiTS
KiTS2021
LAScarQS22_task1
LAScarQS22_task2
LITS
MMWHS
MSD_Colon
MSD_HepaticVessel
MSD_Liver
MSD_Pancreas
MSD_Prostate
MSD_Spleen
PROMISE12
Parse22
Promise09
Prostate_MRI_Segmentation_Dataset
SLIVER07
STACOM_SLAWT
SegThor
Totalsegmentator_dataset
VESSEL2012
VerSe19
VerSe20
WORD
4. 模型结构
基于SAM修改后SAM-Med3D 的 3D 架构。 原始2D组件被转换为3D对应组件,包括3D Image Encoder、3D Prompt Encoder 和3D mask Decoder。采用3D卷积、3D位置编码(PE)和3D layer norm来构建3D模型。
4.1 3D Image Encoder
在 3D 图像编码器中,首先使用内核大小为 (16, 16, 16) 的 3D 卷积嵌入块生成embedding,并与可学习的 3D 绝对位置编码 absolute Positional Encoding (PE) 配对。 这种编码是通过自然地将附加维度扩展到 SAM 的 2D PE 来获得的。 然后将补丁的嵌入输入到 3D 注意力块中。 对于 3D 注意力模块,我们将 3D 相关 PE 合并到 SAM 的多头自注意力(MHSA)模块中,使其能够直接捕获空间细节。
class PatchEmbed3D(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16, 16),
stride: Tuple[int, int] = (16, 16, 16),
padding: Tuple[int, int] = (0, 0, 0),
in_chans: int = 1,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv3d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C X Y Z -> B X Y Z C
x = x.permute(0, 2, 3, 4, 1)
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.