项目实训(十五)——学习如何进行SAM模型的微调

一、我们为什么要微调SAM模型

SAM(Segment Anything Model)是一种通用图像分割模型,能够对任何类型的图像进行分割。然而,医疗领域有一些独特的需求和挑战,使得在这个领域对SAM进行微调是必要的。以下是一些原因:

数据的特殊性:

医疗图像,如MRI、CT扫描、X光片等,与日常图像有很大的不同。它们通常包含更复杂和细致的结构,需要特定的医学知识来正确理解和标注。
医疗图像的纹理、对比度和噪声特征也不同于一般的自然图像,因此需要专门训练来适应这些特征。
高精度需求:

在医疗领域,图像分割的精度非常关键,因为错误可能会导致误诊或漏诊,从而对患者的健康产生严重影响。
SAM虽然是通用模型,但在高精度要求的情况下,需要进行微调以提高其在特定医疗任务中的准确性和可靠性。
标注数据的差异:

医疗图像的标注通常由专业的医学专家完成,标注标准和细节要求更高。
SAM需要适应这些高标准的标注,以提供符合医疗诊断需求的分割结果。
特定病症的检测:

不同的疾病在图像上的表现形式各异,需要针对特定病症的分割能力。例如,肿瘤、病变区域、器官边界等的分割任务都有其独特的挑战。
通过微调,可以使SAM在特定病症的检测和分割上表现更好。

二、如何对SAM进行微调

SAM在医疗领域进行微调的操作并非是一个开创性的工作,我们目前是有很多论文可以参考借鉴的。

比较好的方法就是有adapter,PEFT,阶梯型优化的方法。

我们先来看看现有的一些方法是如何对SAM进行微调的。

《基于SAM的医学图像分割--阶梯式微调》

SAM由三个部分组成,这些部分包括图像编码器、提示编码器和遮罩解码器。图像编码器采用经过MAE预训练的ViT网络来提取图像特征。提示编码器支持四种类型的提示输入:点、框、文本和遮罩。点和框使用位置编码进行嵌入,而文本则使用CLIP中的文本编码器进行嵌入。遮罩使用卷积操作进行嵌入。遮罩解码器旨在以轻量级方式映射图像嵌入和提示嵌入。这两种类型的嵌入通过交叉关注模块进行交互,使用一个嵌入作为查询,另一个嵌入作为键和值向量。最终,使用转置卷积对特征进行上采样。遮罩解码器具有生成多个结果的能力,因为提供的提示可能存在歧义。默认的输出数量设置为三个。值得一提的是,图像编码器对每个输入图像只提取一次图像特征。之后,轻量级的提示编码器和遮罩解码器可以根据不同的输入提示与用户实时在网页浏览器中进行交互。

《使用adapter微调SAM应用于医学图像(2023+Medical SAM Adapter: Adapting SegmentAnything Model for Medical Image)》

方法主要是在ViT块中嵌入Asapter块,模型冻结其他参数,只对adaper块进行更新。

如图b中所示,adapter有down、relu、up三部分构成 。down使用简单的MLP层将给定的嵌入压缩到更小的维度;up使用另一个MLP层将压缩的嵌入扩展回其原始维度,relu是指的relu函数。

图a是原始SAM中的Vit块。

图b表示应用2D医学图像的修改,在多头注意力机制和残差块之后分别插入adapter,并在adater之后进行缩放。

图c表示应用3D医学图像的修改,主要考虑深度相关的影响。将一个VIT块分成两个分支,depth branch 和 space branch,对于给定深度为D的3D样本,我们将D x N x L发送到空间分支中的多头注意力,其中N为嵌入的数量,L为嵌入的长度。在这里,D是操作的数量,并且在N x L上应用交互来学习和抽象空间相关性作为嵌入。在深度分支中,我们首先对输入矩阵进行转置,得到N x D x L,然后将其发送到相同的多头注意。虽然我们使用相同的注意机制,但交互作用应用于D x l。通过这种方式,深度相关性被学习和抽象。最后,我们将深度分支的结果转回其原始形状,并将其添加到空间分支的结果中。

图d表示用于提示的修改。如图加入三个adapter。

训练策略:

图像编码器:与SAM中使用的MAE预训练不同,我们使用了几种自监督学习方法的组合进行预训练。前两种分别是对比嵌入-混合预测(e-Mix)和洗牌嵌入预测(ShED)[32]。e-Mix是一种对比目标,它将一批原始输入嵌入进行加性混合,并用不同的系数对它们进行加权。然后,它训练编码器为混合嵌入生成一个向量,该向量与原始输入的嵌入按混合系数的比例接近。ShED对一小部分嵌入进行洗刷,并用分类器训练编码器来预测哪些嵌入受到了干扰。在SAM的原始实现之后,我们还使用了掩码自编码器(MAE),它掩码给定部分的输入嵌入并训练模型来重建它们。

提示编码器:对于单击提示,正数单击表示前景区域,负数单击表示背景区域。我们使用随机和迭代点击抽样策略的组合来训练这个提示。具体来说,我们首先使用随机抽样进行初始化,然后使用迭代抽样过程添加一些点击。迭代采样策略类似于与真实用户的交互,因为在实践中,每次新的点击都被放置在由网络使用先前点击集产生的预测的错误区域中。我们生成随机抽样,模拟迭代抽样。我们在SAM中使用了不同的文本提示训练策略。在SAM中,作者使用CLIP生成的目标对象作物的图像嵌入作为接近其在CLIP中对应的文本描述或定义的图像嵌入。然而,由于CLIP几乎没有在医学图像数据集上进行训练,因此它很难将图像上的器官/病变与相应的文本定义联系起来。相反,我们首先从ChatGPT中随机生成几个包含目标(即视盘,脑肿瘤)定义作为关键字的自由文本,然后使用CLIP作为训练提示提取文本的嵌入。一个自由文本可以包含多个目标,在这种情况下,我们用所有相应的掩码来监督模型。

### 使用 Docker 对 SAM 模型进行微调 为了在 Docker 中对 SAM (Segment Anything Model) 进行微调,可以遵循一系列特定的操作流程。这不仅涉及创建和配置 Docker 容器环境,还包括准备数据集以及编写必要的脚本来执行微调过程。 #### 创建自定义 Dockerfile 构建适合于 SAM 微调工作的 Docker 镜像的第一步是从合适的基镜像开始,并安装所有必需依赖项: ```dockerfile FROM nvidia/cuda:11.7.0-cudnn8-devel-ubuntu20.04 # 设置工作目录 WORKDIR /workspace # 更新包列表并安装基础工具 RUN apt-get update && \ apt-get install -y python3-pip git wget libgl1-mesa-glx libglib2.0-0 ffmpeg # 升级pip到最新版本 RUN pip3 install --upgrade pip setuptools wheel # 安装PyTorch和其他必要库 RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 # 克隆SAM仓库并进入项目文件夹 RUN git clone https://github.com/facebookresearch/segment-anything.git . ``` 上述命令设置了基于 CUDA 和 cuDNN 的 Ubuntu 环境,并安装了 PyTorch 及其扩展模块以及其他可能需要的 Python 库[^1]。 #### 准备数据集与配置文件 对于微调而言,准备好用于训练的数据集至关重要。通常情况下,会有一个标注好的图片集合作为输入给定模型学习新特征。此外,还需要修改默认配置文件 `model_cfg` 来适应新的任务需求,比如调整超参数、指定路径至权重文件等。 ```python import yaml with open('sam2_hiera_s.yaml', 'r') as file: cfg = yaml.safe_load(file) cfg['DATASETS']['TRAIN'] = ("path/to/train/dataset", ) cfg['SOLVER']['MAX_ITER'] = 50000 cfg['OUTPUT_DIR'] = "./output" with open('customized_sam_config.yaml', 'w') as file: documents = yaml.dump(cfg, file) ``` 这段代码展示了如何读取原始 YAML 格式的配置文档,在其中加入针对当前实验设定的新值之后再保存回磁盘上一个新的位置以便后续使用。 #### 编写微调脚本 最后一步就是编写实际用来启动微调进程的 Python 脚本。这里假设已经有了经过适当预处理后的图像及其对应的标签信息存放在本地硬盘里等待被送入网络内部参与迭代更新操作。 ```python from detectron2.engine import DefaultTrainer from detectron2.config import get_cfg from samgeo import SamGeo2 # 假设这是从引用中提到的一个类 def main(): trainer = DefaultTrainer(cfg=get_cfg()) # 加载定制化的配置选项 with open("customized_sam_config.yaml", "rb") as f: customized_cfg = yaml.load(f.read(), Loader=yaml.FullLoader) for k, v in customized_cfg.items(): setattr(trainer.cfg, k, v) checkpoint_path = "sam2_hiera_small.pt" model_weights = torch.load(checkpoint_path)['model'] trainer.model.load_state_dict(model_weights) trainer.resume_or_load(resume=False) trainer.train() if __name__ == "__main__": main() ``` 此段程序片段说明了怎样利用 Detectron2 提供的功能完成整个端到端的学习循环;同时也体现了当面对不同应用场景时应该如何灵活运用第三方库来简化编码难度[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值