SAM-Med2D 大模型学习笔记(续):训练自己数据集

1、前言、数据集介绍

SAM-Med2D大模型介绍参考上文:第三章:SAM-Med2D大模型复现-CSDN博客

本文将使用SAM-Med2D大模型训练自己的数据集

关于SAM-Med2D大模型官方demo数据集的介绍上文已经介绍过,这里简单回顾下

  • 其中data_demo为数据集的目录,下面有images和masks两个目录,分别存放数据和标签
  • 其中images,就是正常的数据图像,格式是png格式
  • masks格式值得注意,正常的mask是灰度等级的阈值图像【0 1 2 3】,这里把每个类别单独提取出来,变成【0 255】的二值图像,有几个类别就有几张对应的mask模板

例如mask是【0 1 2 2 1】,mask模板有两个,分别是1对应的模板【0 255 0 0 255】,就是只分割前景1。以及只是分割2的模板【0 0 255 255 0】。

mask的命名可以是image名字加上灰度,例如image_1.png和image_2.png

两个json文件如下:

训练数据就是单张image对应的一组mask标签字典

测试集是mask对应的image

2、生成数据的脚本

有了上面介绍,就很简单了,也就是说我们只需要把自己的数据集换成上面格式就可以正常训练了!

其他补充,因为官方的image和mask都是png格式的。

格式需要是png,因为之前本人做过实验只有png保存的二值图像,灰度值才不会乱掉(比如你保存【0 255 0】的jpg读取,np.unique读取可能变成【0 224 223】之类的)

更改文件后缀可以参考:PYTHON 自动化办公:更改图片后缀_改变文件夹里面图片后缀名的pytorch代码-CSDN博客

这里需要把自己的数据集摆放如下:

划分数据集的脚本参考:关于图像分割任务中按照比例将数据集随机划分成训练集和测试集_图像数据划分训练集-CSDN博客

然后运行下面代码就行了:

这个代码会生成image对应mask不同类别的掩膜数据,并且生成两个json文件。这里的目录命名一定要和上面对应

import json
import numpy as np
from tqdm import tqdm
import os
import shutil
from PIL import Image
import cv2


def mkdir():
    root = 'data_demo'
    if os.path.exists(root):
        shutil.rmtree(root)
    os.mkdir(root)
    os.mkdir(os.path.join(root,'images'))
    os.mkdir(os.path.join(root,'masks'))


# 生成训练集
def gen_trainSet(img_suff,msk_suff):
    p = 'RawData/train/images'
    image_list = [os.path.join(p,i) for i in os.listdir(p)]

    with open('data_demo/image2label_train.json', 'a') as jf:
        json_all ={}        # json文件
        for i in tqdm(image_list,desc='generate train set'):
            j = i.replace('images','masks').replace(img_suff,msk_suff)
            assert os.path.exists(j)        # 判断label是否存在

            shutil.copy(i,'data_demo/images')

            mask = np.array(Image.open(j).convert('L'))     # 标签图像
            gray_list = np.unique(mask)

            img_list = []
            for gray in gray_list[1:]:          # 遍历mask所有的分割前景
                ret_mask = np.zeros(mask.shape,dtype=np.uint8)

                ret_mask[mask==gray] =255      # 指定前景为255,其余为背景
                ret_mask[ret_mask<255] = 0

                # 去除小的分割区域
                h,w = ret_mask.shape
                total_pixel = h*w
                if (np.sum(ret_mask!=0)/total_pixel) < 0.005:
                    continue

                ret_name =i.replace(img_suff,'_'+str(gray)+img_suff).replace('RawData/train/images','data_demo/masks')
                cv2.imwrite(ret_name,ret_mask)  # 保存生成的数据

                img_list.append(ret_name)
            if len(img_list) == 0:
                continue
            json_all[i.replace('RawData/train/images','data_demo/images')] = img_list

        json_str = json.dumps(json_all,indent=4)
        jf.write(json_str)


# 生成测试集
def gen_testSet(img_suff,msk_suff):
    p = 'RawData/test/images'
    image_list = [os.path.join(p,i) for i in os.listdir(p)]

    with open('data_demo/label2image_test.json', 'a') as jf:
        json_all ={}        # json文件
        for i in tqdm(image_list,desc='generate test set'):
            j = i.replace('images','masks').replace(img_suff,msk_suff)
            assert os.path.exists(j)        # 判断label是否存在

            shutil.copy(i,'data_demo/images')

            mask = np.array(Image.open(j).convert('L'))     # 标签图像
            gray_list = np.unique(mask)

            for gray in gray_list[1:]:          # 遍历mask所有的分割前景
                ret_mask = np.zeros(mask.shape,dtype=np.uint8)

                ret_mask[mask==gray] =255      # 指定前景为255,其余为背景
                ret_mask[ret_mask<255] = 0

                # 去除小的分割区域
                h,w = ret_mask.shape
                total_pixel = h*w
                if (np.sum(ret_mask!=0)/total_pixel) < 0.005:
                    continue

                ret_name =i.replace(img_suff,'_'+str(gray)+img_suff).replace('RawData/test/images','data_demo/masks')
                cv2.imwrite(ret_name,ret_mask)  # 保存生成的数据

                json_all[ret_name] = i.replace('RawData/test/images','data_demo/images')

        json_str = json.dumps(json_all,indent=4)
        jf.write(json_str)


if __name__ == '__main__':
    imgFormat = '.png'          # image 的后缀
    maskFormat = '.png'         # mask 的后缀

    mkdir()         # 生成目录

    gen_trainSet(img_suff=imgFormat,msk_suff=maskFormat)        # 生成训练数据

    gen_testSet(img_suff=imgFormat,msk_suff=maskFormat)         # 生成测试数据

Tips

运行过程如下

如下:

可以看到image生成了三个对应的mask数据,命名是image的名字加上类别。

下图的8 9 17后缀是原来mask中8 9 17的像素值

测试代码的时候,训练会报错误,大概是len(box)什么分母为零,不能被除的bug。本人猜测可能是生成的组mask里面,前景区域太小之类的,所有脚本里增加点处理

代码会将不足千分之五的分割前景区域删除

3、训练脚本

因为生成的数据就是data_demo目录,所有train脚本不需要任何更改,直接运行即可

这里的parser.add_argument("--mask_num", type=int, default=5, help="get mask number")参数还是没懂

生成的结果如下:每个权重大约2G左右吧

4、测试脚本

代码如下:

python test.py --sam_checkpoint workdir/models/sam-med2d/epoch10_sam.pth

测试结果如下:

1. 概述 SAM-Med 2D视觉大模型是指一种用于医疗图像分析的深度学习模型,专门用于实现脊椎图像的分割任务。该模型能够自动识别并分割出脊椎影像中的各个部分,为医疗诊断提供辅助。本文档介绍了如何复现、训练SAM-Med 2D模型,并应用到自定义的数据集上。 2. 模型复现 复现一个深度学习模型意味着重现模型的训练过程和结果,通常需要具备以下条件: - 模型结构:详细描述了SAM-Med 2D模型的网络架构,包括各层的配置参数和激活函数。 - 训练数据集:涉及RawData目录下提供的数据,这些数据应包含未处理的脊椎影像。 - 训练脚本:提供一个训练脚本,用于设置训练参数,如学习率、批量大小、优化器等,并执行模型训练过程。 - 训练细节:可能还包括预处理步骤,如归一化、裁剪、增强等,以及模型的保存、加载和验证策略。 3. 训练自定义数据集 为了训练SAM-Med 2D模型使用特定的数据集,需要执行以下步骤: - 数据准备:确保自定义数据集符合模型训练的要求,可能需要按照RawData下的格式组织数据。 - 数据预处理:运行process脚本,这个脚本可能会执行数据的转换、格式化等步骤,以生成模型训练所需的数据格式。 - 训练执行:使用train脚本启动训练过程。训练过程中会用到之前生成的数据,按照模型的配置进行参数更新和优化。 4. 模型使用和评估 文档中提及了使用doc文件查看模型的分割和预测结果,这可能包括: - 结果展示:通过可视化的方式展示模型对脊椎图像的分割效果。 - 评估指标:提供一些量化评估指标(如准确度、召回率、Dice系数等),以评估模型在自定义数据集上的性能。 - 模型部署:对于将模型部署到实际的医疗诊断中的方法和步骤的描述。 5. 标签解析 - 数据集:指用于训练和测试模型的脊椎影像集合,这些数据集应该是标注好的,以区分不同的脊椎部分。 - 分割:是指在图像处理中将图像分割成多个部分,特别是将感兴趣的区域(如脊椎)从背景中分离出来。 - 大模型:通常指的是参数量大、结构复杂的深度学习模型,这类模型因其复杂性通常需要大量的数据和计算资源进行训练。 6. 文件结构 - SAM-Med2D:这是压缩包子文件的名称,可能包含了多个子目录和文件,用于支持模型的复现和训练- RawData:包含原始脊椎影像数据。 - process脚本:用于将RawData转化为模型训练可以接受的数据格式。 - train脚本:用于启动模型训练过程。 - doc文件:可能包含模型的分割和预测结果、训练细节和评估结果的文档。 7. 结论 复现和训练SAM-Med 2D视觉大模型涉及一系列复杂步骤,从数据的准备、预处理、到训练和评估。这些步骤共同确保了模型能够在特定任务上达到预期的性能。通过这些流程,研究人员和开发者能够更好地理解和利用深度学习技术进行医学图像分析。
### 使用 SAM 模型训练自定义数据集 为了使用 Segment Anything Model (SAM) 训练自定义数据集,需遵循一系列特定步骤来准备环境、预处理数据并调整模型配置。 #### 准备工作 确保安装必要的依赖库和工具链。对于 Python 环境而言,推荐创建独立的虚拟环境以避免版本冲突。可以利用 `conda` 或者 `venv` 创建新的Python环境[^1]。 ```bash conda create --name sam_env python=3.8 conda activate sam_env pip install -r requirements.txt ``` 其中 `requirements.txt` 文件应包含所有必需包及其兼容版本号。 #### 数据预处理 针对不同类型的输入数据,可能需要执行不同的转换操作。特别是当涉及医学影像或其他特殊格式文件时,如 `.nii.gz` 格式的三维体素数据,则要将其转化为二维切片保存为 `.npy` 数组形式以便于后加载与处理[^2]。 ```python import nibabel as nib import numpy as np def nifti_to_npy(nifti_path, output_dir): img = nib.load(nifti_path) data = img.get_fdata() for i in range(data.shape[-1]): slice_data = data[:, :, i] np.save(f"{output_dir}/slice_{i}.npy", slice_data) # 调用函数实例化转化过程 nifti_to_npy('path/to/your/nifti/file.nii.gz', 'output/directory') ``` 此段代码展示了如何读取 NIfTI 文件并将每一层存储为单独的 NumPy数组文件。 #### 自定义数据集适配 为了让 SAM 接受新类型的数据作为输入,通常还需要编写额外的代码片段用于构建适合该架构使用的 Dataset 类以及 DataLoader 实例对象。这一步骤涉及到定义具体的提示机制——即通过点、边界框或者两者的组合向网络提供指导信息[^3]。 ```python from torch.utils.data import Dataset, DataLoader import os import numpy as np class CustomDataset(Dataset): def __init__(self, root_dir, transform=None): self.root_dir = root_dir self.transform = transform self.files = [f for f in os.listdir(root_dir) if f.endswith('.npy')] def __len__(self): return len(self.files) def __getitem__(self, idx): file_name = os.path.join(self.root_dir, self.files[idx]) image = np.load(file_name).astype(np.float32) sample = {'image': image} if self.transform: sample = self.transform(sample) return sample dataset = CustomDataset(root_dir='processed_slices/') dataloader = DataLoader(dataset, batch_size=4, shuffle=True) ``` 上述代码实现了简单的 PyTorch 风格的数据集类,并设置了基本参数。 #### 微调 SAM 模型 最后,在完成以上准备工作之后就可以着手修改源码中的超参设置部分了;比如改变默认的学习率、迭代次数等选项,从而更好地适应当前任务需求。同时也要注意检查官方文档获取更多有关 API 的细节说明。 ```yaml train: epochs: 50 lr_scheduler: type: CosineAnnealingLR T_max: ${train.epochs} eta_min: 0.0 optimizer: name: AdamW params_lr: 0.0001 ``` 这里给出了一种常见的优化器配置方案样例。
评论 30
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

听风吹等浪起

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值