SAM-Med2D 大模型学习笔记(续):训练自己数据集

1、前言、数据集介绍

SAM-Med2D大模型介绍参考上文:第三章:SAM-Med2D大模型复现-CSDN博客

本文将使用SAM-Med2D大模型训练自己的数据集

关于SAM-Med2D大模型官方demo数据集的介绍上文已经介绍过,这里简单回顾下

  • 其中data_demo为数据集的目录,下面有images和masks两个目录,分别存放数据和标签
  • 其中images,就是正常的数据图像,格式是png格式
  • masks格式值得注意,正常的mask是灰度等级的阈值图像【0 1 2 3】,这里把每个类别单独提取出来,变成【0 255】的二值图像,有几个类别就有几张对应的mask模板

例如mask是【0 1 2 2 1】,mask模板有两个,分别是1对应的模板【0 255 0 0 255】,就是只分割前景1。以及只是分割2的模板【0 0 255 255 0】。

mask的命名可以是image名字加上灰度,例如image_1.png和image_2.png

两个json文件如下:

训练数据就是单张image对应的一组mask标签字典

测试集是mask对应的image

2、生成数据的脚本

有了上面介绍,就很简单了,也就是说我们只需要把自己的数据集换成上面格式就可以正常训练了!

其他补充,因为官方的image和mask都是png格式的。

格式需要是png,因为之前本人做过实验只有png保存的二值图像,灰度值才不会乱掉(比如你保存【0 255 0】的jpg读取,np.unique读取可能变成【0 224 223】之类的)

更改文件后缀可以参考:PYTHON 自动化办公:更改图片后缀_改变文件夹里面图片后缀名的pytorch代码-CSDN博客

这里需要把自己的数据集摆放如下:

划分数据集的脚本参考:关于图像分割任务中按照比例将数据集随机划分成训练集和测试集_图像数据划分训练集-CSDN博客

然后运行下面代码就行了:

这个代码会生成image对应mask不同类别的掩膜数据,并且生成两个json文件。这里的目录命名一定要和上面对应

import json
import numpy as np
from tqdm import tqdm
import os
import shutil
from PIL import Image
import cv2


def mkdir():
    root = 'data_demo'
    if os.path.exists(root):
        shutil.rmtree(root)
    os.mkdir(root)
    os.mkdir(os.path.join(root,'images'))
    os.mkdir(os.path.join(root,'masks'))


# 生成训练集
def gen_trainSet(img_suff,msk_suff):
    p = 'RawData/train/images'
    image_list = [os.path.join(p,i) for i in os.listdir(p)]

    with open('data_demo/image2label_train.json', 'a') as jf:
        json_all ={}        # json文件
        for i in tqdm(image_list,desc='generate train set'):
            j = i.replace('images','masks').replace(img_suff,msk_suff)
            assert os.path.exists(j)        # 判断label是否存在

            shutil.copy(i,'data_demo/images')

            mask = np.array(Image.open(j).convert('L'))     # 标签图像
            gray_list = np.unique(mask)

            img_list = []
            for gray in gray_list[1:]:          # 遍历mask所有的分割前景
                ret_mask = np.zeros(mask.shape,dtype=np.uint8)

                ret_mask[mask==gray] =255      # 指定前景为255,其余为背景
                ret_mask[ret_mask<255] = 0

                # 去除小的分割区域
                h,w = ret_mask.shape
                total_pixel = h*w
                if (np.sum(ret_mask!=0)/total_pixel) < 0.005:
                    continue

                ret_name =i.replace(img_suff,'_'+str(gray)+img_suff).replace('RawData/train/images','data_demo/masks')
                cv2.imwrite(ret_name,ret_mask)  # 保存生成的数据

                img_list.append(ret_name)
            if len(img_list) == 0:
                continue
            json_all[i.replace('RawData/train/images','data_demo/images')] = img_list

        json_str = json.dumps(json_all,indent=4)
        jf.write(json_str)


# 生成测试集
def gen_testSet(img_suff,msk_suff):
    p = 'RawData/test/images'
    image_list = [os.path.join(p,i) for i in os.listdir(p)]

    with open('data_demo/label2image_test.json', 'a') as jf:
        json_all ={}        # json文件
        for i in tqdm(image_list,desc='generate test set'):
            j = i.replace('images','masks').replace(img_suff,msk_suff)
            assert os.path.exists(j)        # 判断label是否存在

            shutil.copy(i,'data_demo/images')

            mask = np.array(Image.open(j).convert('L'))     # 标签图像
            gray_list = np.unique(mask)

            for gray in gray_list[1:]:          # 遍历mask所有的分割前景
                ret_mask = np.zeros(mask.shape,dtype=np.uint8)

                ret_mask[mask==gray] =255      # 指定前景为255,其余为背景
                ret_mask[ret_mask<255] = 0

                # 去除小的分割区域
                h,w = ret_mask.shape
                total_pixel = h*w
                if (np.sum(ret_mask!=0)/total_pixel) < 0.005:
                    continue

                ret_name =i.replace(img_suff,'_'+str(gray)+img_suff).replace('RawData/test/images','data_demo/masks')
                cv2.imwrite(ret_name,ret_mask)  # 保存生成的数据

                json_all[ret_name] = i.replace('RawData/test/images','data_demo/images')

        json_str = json.dumps(json_all,indent=4)
        jf.write(json_str)


if __name__ == '__main__':
    imgFormat = '.png'          # image 的后缀
    maskFormat = '.png'         # mask 的后缀

    mkdir()         # 生成目录

    gen_trainSet(img_suff=imgFormat,msk_suff=maskFormat)        # 生成训练数据

    gen_testSet(img_suff=imgFormat,msk_suff=maskFormat)         # 生成测试数据

Tips

运行过程如下

如下:

可以看到image生成了三个对应的mask数据,命名是image的名字加上类别。

下图的8 9 17后缀是原来mask中8 9 17的像素值

测试代码的时候,训练会报错误,大概是len(box)什么分母为零,不能被除的bug。本人猜测可能是生成的组mask里面,前景区域太小之类的,所有脚本里增加点处理

代码会将不足千分之五的分割前景区域删除

3、训练脚本

因为生成的数据就是data_demo目录,所有train脚本不需要任何更改,直接运行即可

这里的parser.add_argument("--mask_num", type=int, default=5, help="get mask number")参数还是没懂

生成的结果如下:每个权重大约2G左右吧

4、测试脚本

代码如下:

python test.py --sam_checkpoint workdir/models/sam-med2d/epoch10_sam.pth

测试结果如下:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

听风吹等浪起

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值