SAM on the Public Medical Dataset CHAOS (CT Part): Preprocessing
Preface
This article reproduces the experiments of the paper "Segment Anything Model for Medical Images?" (GitHub source code). Building on that code, I make the modifications needed to adapt it to the CHAOS dataset; the main content is the preprocessing of CHAOS data for that GitHub repository.
PS: The code in this article only processes the .dcm files in the CHAOS/CT/1 folder. A follow-up article will cover preprocessing all cases (i.e., every numerically named folder under CT).
1. Data Preprocessing
Step 1: This article uses 小赛看看 (a DICOM viewer) to visualize and preprocess the .dcm data.
First, open and view the data in 小赛看看, as shown in the figure below.
Then export the DCM slices to PNG format and store them at the corresponding path; in this article they are stored in the images folder.
The export operation (DCM file to PNG) in 小赛看看 is shown in the figure:
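If you prefer a scripted alternative to the GUI export, the following is a minimal sketch that converts the .dcm slices of one case to PNG using pydicom and Pillow. The folder names and the simple min-max scaling are assumptions for illustration; 小赛看看 may apply a different window, so the resulting PNGs are not guaranteed to be identical.
import os
import numpy as np
import pydicom
from PIL import Image

dcm_dir = r"CHAOS\CHAOS_Train\Train_Sets\CT\1\DICOM_anon"  # hypothetical input folder
png_dir = r"CHAOS\CHAOS_Train\Train_Sets\CT\1\images"      # hypothetical output folder
os.makedirs(png_dir, exist_ok=True)

for fname in sorted(os.listdir(dcm_dir)):
    if not fname.lower().endswith(".dcm"):
        continue
    ds = pydicom.dcmread(os.path.join(dcm_dir, fname))
    arr = ds.pixel_array.astype(np.float32)
    # simple min-max scaling to 0-255; the GUI tool may use a different intensity window
    arr = (arr - arr.min()) / max(float(arr.max() - arr.min()), 1e-8) * 255.0
    out_name = os.path.splitext(fname)[0] + ".png"
    Image.fromarray(arr.astype(np.uint8)).save(os.path.join(png_dir, out_name))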
Step 2: Rename the images so that their names correspond to the GT (ground truth) masks, which makes them easy to pair up in code.
The code is as follows:
import os

# folder that contains the exported PNG images
image_path = r"D:\深度学习pytorch课程相关代码练习\数据集\医学图像分割数据集\CHAOS\CHAOS_Train\Train_Sets\CT\19\images"

# iterate over every file in the folder
for filename in os.listdir(image_path):
    if filename.endswith('.png'):
        # file name without its extension
        front = os.path.splitext(filename)[0]
        # split the prefix on '-' and take the last part (the slice number)
        last_part = front.split('-')[-1]
        # try to convert the last part to an integer; skip the file if that fails
        try:
            A = int(last_part)
        except ValueError:
            continue
        # build the new file name
        new_filename = 'liver_Img_{:03d}.png'.format(A - 1)
        # build the full old and new file paths
        old_file_path = os.path.join(image_path, filename)
        new_file_path = os.path.join(image_path, new_filename)
        # rename the file
        os.rename(old_file_path, new_file_path)
        print(f"Renamed '{filename}' to '{new_filename}'")

print("All files have been renamed successfully.")
Step 3: Preprocess the 2D CHAOS data. The main changes are to parts of 'pre_grey_rgb2D.py' in the GitHub repository.
First, the following changes were made inside if __name__ == "__main__":
# get all the names of the images in the ground truth folder
if mode == 'train' or mode == 'valid':
    # load the gt labels of every CT case:
    # names = [gt for root, dirs, files in os.walk(args.gt_path) for gt in files if "Ground" in root]
    # for easier testing, only load the gt labels in folder 1 under CT
    names = {"idx": int, "gt": []}  # idx records which CHAOS case folder the gt labels belong to; it is used later to build the image path
    File_folders = sorted(i for i in os.listdir(args.gt_path))
    for i in File_folders:
        if i == "1":
            for folder in os.listdir(os.path.join(args.gt_path, i)):
                if folder == "Ground":
                    for gt in os.listdir(os.path.join(args.gt_path, i, folder)):
                        names["idx"] = i
                        names["gt"].append(gt)
    # save
    save_path = join(save_base, mode, args.task_name)
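A quick way to confirm the change is to print the dictionary after the loop; for case folder 1, names["idx"] should be the string "1" and names["gt"] the list of mask file names (the example output below is illustrative):
# Sanity check (illustrative): expect idx == "1" and gt to hold the mask names from CT/1/Ground
print(names["idx"], len(names["gt"]))
print(sorted(names["gt"])[:3])  # e.g. ['liver_GT_000.png', 'liver_GT_001.png', 'liver_GT_002.png']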
For the loop for gt_name in tqdm(names["gt"]):, I added two variables, interval_img_path and interval_gt_path, which fill in the missing part of the path from *'CHAOS\CHAOS_Train\Train_Sets\CT'* to *'CHAOS\CHAOS_Train\Train_Sets\CT\1\images'* (and *'...\1\Ground'* for the masks). The code is as follows:
# check whether a processed .npz file already exists;
# if it does not, run the processing and save the result as an .npz file
for gt_name in tqdm(names["gt"]):
    if os.path.exists(join(save_path, gt_name.split('.')[0] + ".npz")):
        continue
    img_name = gt_name.replace('liver_GT_', 'liver_Img_')  # pair each gt with its image, using the naming set up in Step 2
    interval_img_path = names["idx"] + "\\images"
    interval_gt_path = names["idx"] + "\\Ground"
    image_path = os.path.join(args.img_path, interval_img_path, img_name)
    if not os.path.exists(image_path):
        continue
    gt_, new_lab_list, img_embedding, resized_size_be_padding, image_ori_size = process(gt_name, interval_gt_path, interval_img_path, img_name, mode=mode)
    if gt_ is not None:
        np.savez_compressed(
            join(save_path, gt_name.split('.')[0] + ".npz"),
            label_except_bk=new_lab_list,
            gts=gt_,
            img_embeddings=img_embedding,
            image_shape=image_ori_size,
            resized_size_before_padding=resized_size_be_padding
        )
print("Num. of processed train images (delete images with no any targets):", len(imgs))
The function def process(gt_name: str, interval_gt_path: str, interval_img_path: str, image_name: str, mode: str):, which is called inside the loop above, was modified accordingly: the two variables interval_img_path and interval_gt_path are inserted into the corresponding paths. Unchanged code in between is omitted; only look at the marked changes.
# first path completion
if image_name == None:
    image_name = gt_name.split(".")[0] + args.img_name_suffix
if mode == "train":
    gt_data = io.imread(join(args.gt_path, interval_gt_path, gt_name))  # H, W
elif mode == "valid":
    gt_data = io.imread(join(args.gt_path.replace("train", "valid"), interval_gt_path, gt_name))
else:
    gt_path = f"data/test_data/{args.task_name}/labels"
    gt_data = io.imread(join(gt_path, interval_gt_path, gt_name))

# ... (unchanged code omitted) ...

# second path completion
if np.sum(gt) > 0:  # has at least 1 object
    # gt: separate each target into size (B, H, W) binary 0-1 uint8
    new_lab_list = list(np.unique(gt))[1:]  # except bk
    new_lab_list.sort()
    gt_ = []
    for l in new_lab_list:
        gt_.append((gt == l) + 0)
    gt_ = np.array(gt_, dtype=np.uint8)
    if mode == "train":
        image_data = io.imread(join(args.img_path, interval_img_path, image_name))
    elif mode == "valid":
        image_data = io.imread(join(args.img_path.replace("train", "valid"), interval_img_path, image_name))
    else:
        img_path = f"data/test_data/{args.task_name}/images"
        image_data = io.imread(join(img_path, interval_img_path, image_name))

# ... (unchanged code omitted) ...
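Concretely, with the default arguments the two completions expand to paths of the following form for case folder 1; the base path and file names in this sketch are placeholders, and the backslash separators assume Windows:
import os

base = r"..\CHAOS\CHAOS_Train\Train_Sets\CT"  # placeholder for args.gt_path / args.img_path
print(os.path.join(base, "1" + "\\Ground", "liver_GT_000.png"))   # -> ..\CHAOS\CHAOS_Train\Train_Sets\CT\1\Ground\liver_GT_000.png
print(os.path.join(base, "1" + "\\images", "liver_Img_000.png"))  # -> ..\CHAOS\CHAOS_Train\Train_Sets\CT\1\images\liver_Img_000.png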
After the steps above, run the modified **pre_grey_rgb2D.py**; the resulting .npz files are saved under "Segment-Anything-Model-for-Medical-Images-main\data\precompute_vit_b\train\CHAOS_CT".
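To verify one of the outputs, you can load an .npz file and inspect its keys and shapes; the file name below is illustrative.
# Minimal sketch for inspecting one generated .npz file (file name is illustrative).
import numpy as np

npz = np.load(r"data\precompute_vit_b\train\CHAOS_CT\liver_GT_000.npz")
print(npz.files)                    # ['label_except_bk', 'gts', 'img_embeddings', 'image_shape', 'resized_size_before_padding']
print(npz["gts"].shape)             # (num_objects, H, W) binary masks
print(npz["img_embeddings"].shape)  # SAM ViT-B image embedding, typically (256, 64, 64)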
The complete preprocessing code used in this article is as follows:
# %% import packages
import numpy as np
import os
from glob import glob
import pandas as pd
join = os.path.join
from skimage import transform, io, segmentation
from tqdm import tqdm
import torch
from torchvision.transforms.functional import InterpolationMode
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import argparse
# set up the parser
parser = argparse.ArgumentParser(description="preprocess grey and RGB images")
# add arguments to the parser
parser.add_argument(
    "-i",
    "--img_path",
    type=str,
    default="../../../数据集/医学图像分割数据集/CHAOS/CHAOS_Train/Train_Sets/CT",
    help="path to the images",
)
parser.add_argument(
    "-gt",
    "--gt_path",
    type=str,
    default="../../../数据集/医学图像分割数据集/CHAOS/CHAOS_Train/Train_Sets/CT",
    help="path to the ground truth (gt)",
)
parser.add_argument(
    "-task",
    "--task_name",
    type=str,
    default="CHAOS_CT",
    help="name of the test dataset",
)
parser.add_argument(
    "--csv",
    type=str,
    default=None,
    help="path to the csv file",
)
parser.add_argument(
    "-o",
    "--npz_path",
    type=str,
    default="data",
    help="path to save the npz files",
)
parser.add_argument(
    "--data_name",
    type=str,
    default="demo2d",
    help="dataset name; used to name the final npz file, e.g., demo2d.npz",
)
parser.add_argument("--image_size", type=int, default=1024, help="image size")
parser.add_argument(
    "--img_name_suffix", type=str, default=".png", help="image name suffix"
)
# parser.add_argument("--label_id", type=int, default=255, help="label id")
parser.add_argument("--model_type", type=str, default="vit_b", help="model type")
parser.add_argument(
    "--checkpoint",
    type=str,
    default="sam_vit_b_01ec64.pth",
    help="original sam checkpoint",
)
parser.add_argument("--device", type=str, default="cuda:0", help="device")
parser.add_argument("--seed", type=int, default=2023, help="random seed")
# parse the arguments
args = parser.parse_args()
# create a directory to save the npz files
save_base = args.npz_path + "/precompute_" + args.model_type
# convert 2d grey or rgb images to npz file
imgs = []
gts = []
img_embeddings = []
# set up the model
# get the model from sam_model_registry using the model_type argument
# and load it with checkpoint argument
# download and save the SAM checkpoint (ViT-B SAM model):
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint, device=args.device).to(
    args.device
)
# ResizeLongestSide (1024), including image and gt
sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
def process(gt_name: str, interval_gt_path: str, interval_img_path: str, image_name: str, mode: str):
    if image_name == None:
        image_name = gt_name.split(".")[0] + args.img_name_suffix
    if mode == "train":
        gt_data = io.imread(join(args.gt_path, interval_gt_path, gt_name))  # H, W
    elif mode == "valid":
        gt_data = io.imread(join(args.gt_path.replace("train", "valid"), interval_gt_path, gt_name))
    else:
        gt_path = f"data/test_data/{args.task_name}/labels"
        gt_data = io.imread(join(gt_path, interval_gt_path, gt_name))
    # if the ground truth is an RGB image, keep only the first channel
    if len(gt_data.shape) == 3:
        gt_data = gt_data[:, :, 0]
    assert len(gt_data.shape) == 2, "ground truth should be 2D"
    # resize ground truth image
    # resize_gt = sam_transform.apply_image(gt_data, interpolation=InterpolationMode.NEAREST)  # ResizeLong (resized_h, 1024)
    # gt_data = sam_model.preprocess_for_gt(resize_gt)
    # exclude tiny objects (considering multi-object):
    # labels covering no more than 50 pixels are set to 0, so the processed
    # label only keeps objects larger than 50 pixels
    gt = gt_data.copy()
    label_list = np.unique(gt_data)[1:]
    del_lab = []  # for check
    for label in label_list:
        gt_single = (gt_data == label) + 0
        if np.sum(gt_single) <= 50:
            gt[gt == label] = 0
            del_lab.append(label)
    assert len(list(np.unique(gt)) + del_lab) == len(list(label_list) + [0])
    # preprocess the image
    if np.sum(gt) > 0:  # has at least 1 object
        # gt: separate each target into size (B, H, W) binary 0-1 uint8
        new_lab_list = list(np.unique(gt))[1:]  # except bk
        new_lab_list.sort()
        gt_ = []
        for l in new_lab_list:
            gt_.append((gt == l) + 0)
        gt_ = np.array(gt_, dtype=np.uint8)
        if mode == "train":
            image_data = io.imread(join(args.img_path, interval_img_path, image_name))
        elif mode == "valid":
            image_data = io.imread(join(args.img_path.replace("train", "valid"), interval_img_path, image_name))
        else:
            img_path = f"data/test_data/{args.task_name}/images"
            image_data = io.imread(join(img_path, interval_img_path, image_name))
        image_ori_size = image_data.shape[:2]
        # Remove any alpha channel if present.
        if image_data.shape[-1] > 3 and len(image_data.shape) == 3:
            image_data = image_data[:, :, :3]
        # If image is grayscale, then repeat the last channel to convert to rgb
        if len(image_data.shape) == 2:
            image_data = np.repeat(image_data[:, :, None], 3, axis=-1)
        # nii preprocess start (clip the intensity):
        # 0.95% of pixel intensities lie below lower_bound, 99.5% lie below upper_bound
        lower_bound, upper_bound = np.percentile(image_data, 0.95), np.percentile(
            image_data, 99.5
        )
        image_data_pre = np.clip(image_data, lower_bound, upper_bound)
        # min-max normalize and scale
        image_data_pre = (
            (image_data_pre - np.min(image_data_pre))
            / (np.max(image_data_pre) - np.min(image_data_pre))
            * 255.0
        )
        image_data_pre[image_data == 0] = 0  # ensure 0-255
        image_data_pre = np.uint8(image_data_pre)
        imgs.append(image_data_pre)
        # resize image to 3*1024*1024
        resize_img = sam_transform.apply_image(image_data_pre, interpolation=InterpolationMode.BILINEAR)  # ResizeLong
        resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1))[None, :, :, :].to(
            args.device
        )  # (1, 3, resized_h, 1024)
        resized_size_be_padding = tuple(resize_img_tensor.shape[-2:])
        input_image = sam_model.preprocess(resize_img_tensor)  # padding to (1, 3, 1024, 1024)
        assert input_image.shape == (
            1,
            3,
            sam_model.image_encoder.img_size,
            sam_model.image_encoder.img_size,
        ), "input image should be resized to 1024*1024"
        assert input_image.shape[-2:] == (1024, 1024)
        # pre-compute the image embedding
        if mode != "train":
            sam_model.eval()
        with torch.no_grad():
            embedding = sam_model.image_encoder(input_image)
            img_embedding = embedding.cpu().numpy()[0]
        return gt_, new_lab_list, img_embedding, resized_size_be_padding, image_ori_size
    else:
        print(mode, gt_name)
        return None, None, None, None, None
if __name__ == "__main__":
    mode = 'train'  # train
    if args.csv != None:
        # if data is presented in csv format
        # columns must be named image_filename and mask_filename respectively
        try:
            os.path.exists(args.csv)
        except FileNotFoundError as e:
            print(f"File {args.csv} not found!!")
        df = pd.read_csv(args.csv)
        bar = tqdm(df.iterrows(), total=len(df))
        for idx, row in bar:
            process(row.mask_filename, row.image_filename)
    else:
        # get all the names of the images in the ground truth folder
        if mode == 'train' or mode == 'valid':
            # load the gt labels of every CT case:
            # names = [gt for root, dirs, files in os.walk(args.gt_path) for gt in files if "Ground" in root]
            # for easier testing, only load the gt labels in folder 1 under CT
            names = {"idx": int, "gt": []}  # idx records which CHAOS case folder the gt labels belong to; it is used later to build the image path
            File_folders = sorted(i for i in os.listdir(args.gt_path))
            for i in File_folders:
                if i == "1":
                    for folder in os.listdir(os.path.join(args.gt_path, i)):
                        if folder == "Ground":
                            for gt in os.listdir(os.path.join(args.gt_path, i, folder)):
                                names["idx"] = i
                                names["gt"].append(gt)
            # save
            save_path = join(save_base, mode, args.task_name)
        else:
            gt_path = f"{args.gt_path}/{args.task_name}/labels"
            args.img_path = f"{args.gt_path}/{args.task_name}/images"
            names = sorted(os.listdir(gt_path))
            # save
            save_path = join(save_base, mode, args.task_name)
        # print the number of images found in the ground truth folder
        print("Num. of all train images:", len(names["gt"]))
        os.makedirs(save_path, exist_ok=True)
        # check whether a processed .npz file already exists;
        # if it does not, run the processing and save the result as an .npz file
        for gt_name in tqdm(names["gt"]):
            if os.path.exists(join(save_path, gt_name.split('.')[0] + ".npz")):
                continue
            img_name = gt_name.replace('liver_GT_', 'liver_Img_')  # pair each gt with its image, using the naming set up in Step 2
            interval_img_path = names["idx"] + "\\images"
            interval_gt_path = names["idx"] + "\\Ground"
            image_path = os.path.join(args.img_path, interval_img_path, img_name)
            if not os.path.exists(image_path):
                continue
            gt_, new_lab_list, img_embedding, resized_size_be_padding, image_ori_size = process(gt_name, interval_gt_path, interval_img_path, img_name, mode=mode)
            if gt_ is not None:
                np.savez_compressed(
                    join(save_path, gt_name.split('.')[0] + ".npz"),
                    label_except_bk=new_lab_list,
                    gts=gt_,
                    img_embeddings=img_embedding,
                    image_shape=image_ori_size,
                    resized_size_before_padding=resized_size_be_padding
                )
        print("Num. of processed train images (delete images with no any targets):", len(imgs))