【SAM用于医学公共数据集———CHAOS(CT部分)】

SAM用于医学公共数据集———CHAOS(CT部分)—预处理

提示:这里可以添加系列文章的所有文章的目录,目录需要自己手动添加
例如:第一章 Python 机器学习入门之pandas的使用


提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

提示:这里可以添加本文要记录的大概内容:

本文旨在将该论文[Segment Anything Model for Medical Images?]中的实验进行复现, GitHub源码。在此代码基础上进行相应的修改来满足对CHAOS数据集的适应,主要内容就是CHAOS数据在该GitHub上的预处理过程。
PS: 本文代码进处理了CHAOS/CT/1这个文件夹的dcm文件。 后续再发一个处理所有文件(即CT下所有以数字命名的文件)的预处理过程。


提示:以下是本篇文章正文内容,下面案例可供参考

一、数据预处理

Step 1:本文选用 小赛看看来对.dcm数据进行可视化与预处理操作。
首先将数据在小赛看看中打开并可视化,如下图所示。DCM数据可视化
随后,将DCM导出为png格式并存储在对应路径下中, 本文所存储的是images文件夹下。
DCM2PNG 存储路径
小赛看看中 DCM文件导出为PNG文件操作如图所示:
STEP1
STEP2&3
STEP4
Step 2:将Gt与images的名称进行对应操作,使得在编码中方便操作。
代码如下:

import os
import re

# 定义图片所在路径
image_path = r"D:\深度学习pytorch课程相关代码练习\数据集\医学图像分割数据集\CHAOS\CHAOS_Train\Train_Sets\CT\19\images"

# 遍历指定路径下的所有文件
for filename in os.listdir(image_path):
    if filename.endswith('.png'):
        # 提取文件名前缀
        front = os.path.splitext(filename)[0]
        
        # 使用'-'分割前缀,并获取最后一部分的后三位
        last_part = front.split('-')[-1]
        
        # 尝试将最后一部分转为整数,如果转换失败则跳过该文件
        try:
            A = int(last_part)
        except ValueError:
            continue
        
        # 生成新的文件名
        new_filename = 'liver_Img_{:03d}.png'.format(A - 1)
        
        # 拼接完整的旧文件路径和新文件路径
        old_file_path = os.path.join(image_path, filename)
        new_file_path = os.path.join(image_path, new_filename)
        
        # 重命名文件
        os.rename(old_file_path, new_file_path)
        print(f"Renamed '{filename}' to '{new_filename}'")

print("All files have been renamed successfully.")

Step 3:将CHAOS中2D数据进行预处理。主要修改的是GitHub库中’pre_grey_rgb2D.py‘的部分代码。
首先是在if __name__ == "__main__"中进行了如下修改

 # get all the names of the images in the ground truth folder
        if mode == 'train' or mode == 'valid':
            #  加载所有CT数据的gt label
            #  names = [gt for root, dirs, files in os.walk(args.gt_path) for gt in files if "Ground" in root]
            
            #  只加载CT下1文件夹中所有的gt label,方便测试
            names = {"idx":int , "gt":[]}  # idx用于表示CHAOS数据集下对应的gt label所属的文件夹名,用于在读取image时提供路径帮助
            File_folders = sorted(i for i in os.listdir(args.gt_path))
            for i in File_folders:
                if i=="1":
                    for folder in os.listdir(os.path.join(args.gt_path, i)):
                        if folder=="Ground":
                            for gt in os.listdir(os.path.join(args.gt_path, i, folder)):
                                names["idx"] = i
                                names["gt"].append(gt)
            # save
            save_path = join(save_base, mode, args.task_name)

而对于循环语句 for gt_name in tqdm(names["gt"]):,我添加了两个变量 interval_img_path 与 interval_gt_path,分别用于来补全从*‘CHAOS\CHAOS_Train\Train_Sets\CT’路径‘CHAOS\CHAOS_Train\Train_Sets\CT\1\images(Ground)’路径* 的信息。代码如下:

   # 它检查是否已经存在处理后的 .npz 文件.
        # 如果不存在则执行处理过程,并将处理后的数据保存为 .npz 文件。
        for gt_name in tqdm(names["gt"]):
            if os.path.exists(join(save_path, gt_name.split('.')[0] + ".npz")):
                continue
            img_name = gt_name.replace('liver_GT_','liver_Img_')  # 构建了img与gt的对应关系, 在Step 1中对应
            interval_img_path = names["idx"] + "\\images"
            interval_gt_path = names["idx"] + "\\Ground"
            image_path = os.path.join(args.img_path, interval_img_path, img_name)
            if not os.path.exists(image_path):
                continue
            gt_, new_lab_list, img_embedding, resized_size_be_padding, image_ori_size = process(gt_name, interval_gt_path, interval_img_path, img_name, mode=mode)
            if gt_ is not None:
                np.savez_compressed(
                    join(save_path, gt_name.split('.')[0] + ".npz"),
                    label_except_bk=new_lab_list,
                    gts=gt_,
                    img_embeddings=img_embedding,
                    image_shape=image_ori_size,
                    resized_size_before_padding=resized_size_be_padding
                )
        print("Num. of processed train images (delete images with no any targets):", len(imgs))

与for循环与剧中gt_, new_lab_list, img_embedding, resized_size_be_padding, image_ori_size = process(gt_name, interval_gt_path, interval_img_path, img_name, mode=mode)对应的def process(gt_name: str, interval_gt_path:str, interval_img_path:str, image_name: str, mode: str):这个函数也进行了如下修改,主要就是把这两个变量 interval_img_path 与 interval_gt_path分别在相应的路径上进行了补全。中间省略了部分未修改的代码,大家只看对应出的修改即可。

  # 第一处补全
   if image_name == None:
        image_name = gt_name.split(".")[0] + args.img_name_suffix
    if mode == "train":
        gt_data = io.imread(join(args.gt_path, interval_gt_path, gt_name)) # H, W
    elif mode == "valid":
        gt_data = io.imread(join(args.gt_path.replace("train", "valid"), interval_gt_path, gt_name))
    else:
        gt_path = f"data/test_data/{args.task_name}/labels"
        gt_data = io.imread(join(gt_path, interval_gt_path, gt_name))
   ... ... ...
   # 第二处补全
     if np.sum(gt) > 0: # has at least 1 object
        # gt: seperate each target into size (B, H, W) binary 0-1 uint8
        new_lab_list = list(np.unique(gt))[1:] # except bk
        new_lab_list.sort()
        gt_ = []
        for l in new_lab_list:
            gt_.append((gt == l) + 0)
        gt_ = np.array(gt_, dtype=np.uint8)

        if mode == "train":
            image_data = io.imread(join(args.img_path, interval_img_path, image_name))
        elif mode == "valid":
            image_data = io.imread(join(args.img_path.replace("train", "valid"), interval_img_path, image_name))
        else:
            img_path = f"data/test_data/{args.task_name}/images"
            image_data = io.imread(join(img_path, interval_img_path, image_name))
    ... ... ...

····经过以上几个步骤之后,我们运行修改过的**“pre_grey_rgb2D.py”**的文件,运行出来的.npz文件将保存在该路径下"Segment-Anything-Model-for-Medical-Images-main\data\precompute_vit_b\train\CHAOS_CT"。
本文预处理部分完整代码如下:

  # %% import packages
import numpy as np
import os
from glob import glob
import pandas as pd

join = os.path.join
from skimage import transform, io, segmentation
from tqdm import tqdm
import torch
from torchvision.transforms.functional import InterpolationMode
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import argparse

# set up the parser
parser = argparse.ArgumentParser(description="preprocess grey and RGB images")

# add arguments to the parser
parser.add_argument(
    "-i",
    "--img_path",
    type=str,
    default=f"../../../数据集/医学图像分割数据集/CHAOS/CHAOS_Train/Train_Sets/CT",  
    help="path to the images",
)
parser.add_argument(
    "-gt",
    "--gt_path",
    type=str,
    default=f"../../../数据集/医学图像分割数据集/CHAOS/CHAOS_Train/Train_Sets/CT",  
    help="path to the ground truth (gt)",
)

parser.add_argument(
    "-task",
    "--task_name",
    type=str,
    default=f"CHAOS_CT",
    help="name to test dataset",
)

parser.add_argument(
    "--csv",
    type=str,
    default=None,
    help="path to the csv file",
)

parser.add_argument(
    "-o",
    "--npz_path",
    type=str,
    default=f"data",
    help="path to save the npz files",
)
parser.add_argument(
    "--data_name",
    type=str,
    default="demo2d",
    help="dataset name; used to name the final npz file, e.g., demo2d.npz",
)
parser.add_argument("--image_size", type=int, default=1024, help="image size")
parser.add_argument(
    "--img_name_suffix", type=str, default=".png", help="image name suffix"
)
# parser.add_argument("--label_id", type=int, default=255, help="label id")
parser.add_argument("--model_type", type=str, default="vit_b", help="model type")
parser.add_argument(
    "--checkpoint",
    type=str,
    default="sam_vit_b_01ec64.pth",
    help="original sam checkpoint",
)
parser.add_argument("--device", type=str, default="cuda:0", help="device")
parser.add_argument("--seed", type=int, default=2023, help="random seed")

# parse the arguments
args = parser.parse_args()

# create a directory to save the npz files
save_base = args.npz_path + "/precompute_" + args.model_type

# convert 2d grey or rgb images to npz file
imgs = []
gts = []
img_embeddings = []

# set up the model
# get the model from sam_model_registry using the model_type argument
# and load it with checkpoint argument
# download save the SAM checkpoint.
# [https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth](VIT-B SAM model)

sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint, device = args.device).to(
    args.device
)

# ResizeLongestSide (1024), including image and gt
sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)

def process(gt_name: str, interval_gt_path:str, interval_img_path:str, image_name: str, mode: str):
    if image_name == None:
        image_name = gt_name.split(".")[0] + args.img_name_suffix
    if mode == "train":
        gt_data = io.imread(join(args.gt_path, interval_gt_path, gt_name)) # H, W
    elif mode == "valid":
        gt_data = io.imread(join(args.gt_path.replace("train", "valid"), interval_gt_path, gt_name))
    else:
        gt_path = f"data/test_data/{args.task_name}/labels"
        gt_data = io.imread(join(gt_path, interval_gt_path, gt_name))
    # 若是RGN图像则选择第1个通道
    if len(gt_data.shape) == 3:
        gt_data = gt_data[:, :, 0]
    assert len(gt_data.shape) == 2, "ground truth should be 2D"

    # resize ground truth image
    # resize_gt = sam_transform.apply_image(gt_data, interpolation=InterpolationMode.NEAREST) # ResizeLong (resized_h, 1024)
    # gt_data = sam_model.preprocess_for_gt(resize_gt)

    # exclude tiny objects (considering multi-object),
    # 将2d图像中像素值<50的图像去除,全设为0。保证处理后的label对象只含有大于50像素的对象。
    gt = gt_data.copy()
    label_list = np.unique(gt_data)[1:]
    del_lab = [] # for check
    for label in label_list:
        gt_single = (gt_data == label) + 0
        if np.sum(gt_single) <= 50:
            gt[gt == label] = 0
            del_lab.append(label)
    assert len(list(np.unique(gt)) + del_lab) == len(list(label_list) + [0])

    # 对图像进行预处理
    if np.sum(gt) > 0: # has at least 1 object
        # gt: seperate each target into size (B, H, W) binary 0-1 uint8
        new_lab_list = list(np.unique(gt))[1:] # except bk
        new_lab_list.sort()
        gt_ = []
        for l in new_lab_list:
            gt_.append((gt == l) + 0)
        gt_ = np.array(gt_, dtype=np.uint8)

        if mode == "train":
            image_data = io.imread(join(args.img_path, interval_img_path, image_name))
        elif mode == "valid":
            image_data = io.imread(join(args.img_path.replace("train", "valid"), interval_img_path, image_name))
        else:
            img_path = f"data/test_data/{args.task_name}/images"
            image_data = io.imread(join(img_path, interval_img_path, image_name))
        image_ori_size = image_data.shape[:2]
        # Remove any alpha channel if present.
        if image_data.shape[-1] > 3 and len(image_data.shape) == 3:
            image_data = image_data[:, :, :3]
        # If image is grayscale, then repeat the last channel to convert to rgb
        if len(image_data.shape) == 2:
            image_data = np.repeat(image_data[:, :, None], 3, axis=-1)
        # nii preprocess start (clip the intensity)
        lower_bound, upper_bound = np.percentile(image_data, 0.95), np.percentile(
            image_data, 99.5 # Intensity of 0.95% pixels in image_data lower than lower_bound
                             # Intensity of 99.5% pixels in image_data lower than upper_bound
        )
        image_data_pre = np.clip(image_data, lower_bound, upper_bound)
        # min-max normalize and scale
        image_data_pre = (
            (image_data_pre - np.min(image_data_pre))
            / (np.max(image_data_pre) - np.min(image_data_pre))
            * 255.0
        )
        image_data_pre[image_data == 0] = 0 # ensure 0-255
        image_data_pre = np.uint8(image_data_pre)
        imgs.append(image_data_pre)

        # resize image to 3*1024*1024
        resize_img = sam_transform.apply_image(image_data_pre, interpolation=InterpolationMode.BILINEAR) # ResizeLong
        resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1))[None, :, :, :].to(
            args.device
        ) # (1, 3, resized_h, 1024)
        resized_size_be_padding = tuple(resize_img_tensor.shape[-2:])
        input_image = sam_model.preprocess(resize_img_tensor) # padding to (1, 3, 1024, 1024)
        assert input_image.shape == (
            1,
            3,
            sam_model.image_encoder.img_size,
            sam_model.image_encoder.img_size,
        ), "input image should be resized to 1024*1024"
        assert input_image.shape[-2:] == (1024, 1024)
        # pre-compute the image embedding
        if mode != "train":
            sam_model.eval()
        with torch.no_grad():
            embedding = sam_model.image_encoder(input_image)
            img_embedding = embedding.cpu().numpy()[0]
        return gt_, new_lab_list, img_embedding, resized_size_be_padding, image_ori_size
    else:
        print(mode, gt_name)
        return None, None, None, None, None

if __name__ == "__main__":
    mode = 'train'  # train
    if args.csv != None:
        # if data is presented in csv format
        # columns must be named image_filename and mask_filename respectively
        try:
            os.path.exists(args.csv)
        except FileNotFoundError as e:
            print(f"File {args.csv} not found!!")

        df = pd.read_csv(args.csv)
        bar = tqdm(df.iterrows(), total=len(df))
        for idx, row in bar:
            process(row.mask_filename, row.image_filename)

    else:
        # get all the names of the images in the ground truth folder
        if mode == 'train' or mode == 'valid':
            #  加载所有CT数据的gt label
            #  names = [gt for root, dirs, files in os.walk(args.gt_path) for gt in files if "Ground" in root]
            
            #  只加载CT下1文件夹中所有的gt label,方便测试
            names = {"idx":int , "gt":[]}  # idx用于表示CHAOS数据集下对应的gt label所属的文件夹名,用于在读取image时提供路径帮助
            File_folders = sorted(i for i in os.listdir(args.gt_path))
            for i in File_folders:
                if i=="1":
                    for folder in os.listdir(os.path.join(args.gt_path, i)):
                        if folder=="Ground":
                            for gt in os.listdir(os.path.join(args.gt_path, i, folder)):
                                names["idx"] = i
                                names["gt"].append(gt)
            # save
            save_path = join(save_base, mode, args.task_name)
        else:
            gt_path = f"{args.gt_path}/{args.task_name}/labels"
            args.img_path = f"{args.gt_path}/{args.task_name}/images"
            names = sorted(os.listdir(gt_path))
            # save
            save_path = join(save_base, mode,args.task_name)
        # print the number of images found in the ground truth folder
        print("Num. of all train images:", len(names["gt"]))
        
        os.makedirs(save_path, exist_ok=True)
        
        # 它检查是否已经存在处理后的 .npz 文件.
        # 如果不存在则执行处理过程,并将处理后的数据保存为 .npz 文件。
        for gt_name in tqdm(names["gt"]):
            if os.path.exists(join(save_path, gt_name.split('.')[0] + ".npz")):
                continue
            img_name = gt_name.replace('liver_GT_','liver_Img_')  # 构建了img与gt的对应关系
            interval_img_path = names["idx"] + "\\images"
            interval_gt_path = names["idx"] + "\\Ground"
            image_path = os.path.join(args.img_path, interval_img_path, img_name)
            if not os.path.exists(image_path):
                continue
            gt_, new_lab_list, img_embedding, resized_size_be_padding, image_ori_size = process(gt_name, interval_gt_path, interval_img_path, img_name, mode=mode)
            if gt_ is not None:
                np.savez_compressed(
                    join(save_path, gt_name.split('.')[0] + ".npz"),
                    label_except_bk=new_lab_list,
                    gts=gt_,
                    img_embeddings=img_embedding,
                    image_shape=image_ori_size,
                    resized_size_before_padding=resized_size_be_padding
                )
        print("Num. of processed train images (delete images with no any targets):", len(imgs))

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值