不同数据集的处理方式
在上一个博客中,我提供了四个数据集的.h5版本的数据,这个版本的数据可以直接拿来用来进行模型训练。但是鉴于可能有朋友没有下载到我提供的数据,或者是想自己对数据进行预处理,我这里也提供一下LA,ACDC和PROMISE12数据集的数据处理方式。
LA 数据集
这里LA数据集的原始数据类型是lgemri.nrrd。我们需要将其处理为.h5数据类型。一般在半监督医学图像分割任务中,LA数据集的数据大小都会被处理为[112, 112, 80]。然后对于每一个实例,处理好的最终的文件名叫做mri_norm2.h5。
import numpy as np
from glob import glob
from tqdm import tqdm
import h5py
import nrrd
output_size =[112, 112, 80]
def covert_h5():
listt = glob('../data/LA/2018LA_Seg_Training Set/*/lgemri.nrrd')
for item in tqdm(listt):
image, img_header = nrrd.read(item)
label, gt_header = nrrd.read(item.replace('lgemri.nrrd', 'laendo.nrrd'))
label = (label == 255).astype(np.uint8)
w, h, d = label.shape
tempL = np.nonzero(label)
minx, maxx = np.min(tempL[0]), np.max(tempL[0])
miny, maxy = np.min(tempL[1]), np.max(tempL[1])
minz, maxz = np.min(tempL[2]), np.max(tempL[2])
px = max(output_size[0] - (maxx - minx), 0) // 2
py = max(output_size[1] - (maxy - miny), 0) // 2
pz = max(output_size[2] - (maxz - minz), 0) // 2
minx = max(minx - np.random.randint(10, 20) - px, 0)
maxx = min(maxx + np.random.randint(10, 20) + px, w)
miny = max(miny - np.random.randint(10, 20) - py, 0)
maxy = min(maxy + np.random.randint(10, 20) + py, h)
minz = max(minz - np.random.randint(5, 10) - pz, 0)
maxz = min(maxz + np.random.randint(5, 10) + pz, d)
image = (image - np.mean(image)) / np.std(image)
image = image.astype(np.float32)
image = image[minx:maxx, miny:maxy]
label = label[minx:maxx, miny:maxy]
print(label.shape)
f = h5py.File(item.replace('lgemri.nrrd', 'mri_norm2.h5'), 'w')
f.create_dataset('image', data=image, compression="gzip")
f.create_dataset('label', data=label, compression="gzip")
f.close()
if __name__ == '__main__':
covert_h5()
ACDC 数据集
可以看到,这里ACDC数据集的原始数据类型是.nii.gz。我们的目标也是将其转化为.h5文件。但是与LA数据集不同的是,一个ACDC volumes包含不同的frame,然后一个frame又包含了多个2D slices。所以对于每一个ACDC frame,会为每一个slice生成一个.h5文件。可以看到,在代码中,对应两个不同的for循环,第一个for循环是为了遍历每一个ACDC volume,而第二个for循环对应的是一个ACDC frame包含的多个切片。这里ACDC数据集的数据类型和Pancreas数据集是一样的,所以可以参考这个处理方式对Pancreas数据集进行处理。
import glob
import os
import h5py
import numpy as np
import SimpleITK as sitk
slice_num = 0
mask_path = sorted(glob.glob("/home/xdluo/data/ACDC/image/*.nii.gz"))
for case in mask_path:
img_itk = sitk.ReadImage(case)
origin = img_itk.GetOrigin()
spacing = img_itk.GetSpacing()
direction = img_itk.GetDirection()
image = sitk.GetArrayFromImage(img_itk)
msk_path = case.replace("image", "label").replace(".nii.gz", "_gt.nii.gz")
if os.path.exists(msk_path):
print(msk_path)
msk_itk = sitk.ReadImage(msk_path)
mask = sitk.GetArrayFromImage(msk_itk)
image = (image - image.min()) / (image.max() - image.min())
print(image.shape)
image = image.astype(np.float32)
item = case.split("/")[-1].split(".")[0]
if image.shape != mask.shape:
print("Error")
print(item)
for slice_ind in range(image.shape[0]):
f = h5py.File(
'/home/xdluo/data/ACDC/data/{}_slice_{}.h5'.format(item, slice_ind), 'w')
f.create_dataset(
'image', data=image[slice_ind], compression="gzip")
f.create_dataset('label', data=mask[slice_ind], compression="gzip")
f.close()
slice_num += 1
print("Converted all ACDC volumes to 2D slices")
print("Total {} slices".format(slice_num))
PROMISE12 数据集
这个数据集的处理方式与ACDC相似,因为这里每一个prostate的实例也包含了多个切片。
import os
import SimpleITK as sitk
import h5py
import matplotlib.pyplot as plt
import numpy as np
path = "../../data/Prostate"
with open(path + '/all.list', 'r') as f1:
train_list = f1.readlines()
train_list = [item.replace('\n', '') for item in train_list]
for image_name in train_list:
image = sitk.ReadImage(path + '/training_data/' + image_name + '.mhd')
label = sitk.ReadImage(path + '/training_data/' + image_name + '_segmentation.mhd')
image_array = sitk.GetArrayFromImage(image)
label_array = sitk.GetArrayFromImage(label)
image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min())
# image = image.astype(np.float32)
selected_h5 = h5py.File(path+"/data/"+image_name+".h5", "w")
selected_h5.create_dataset("image", data=image_array)
selected_h5.create_dataset("label", data=label_array)
selected_h5.close()
for i in range(image_array.shape[0]):
selected_image_slice = image_array[i]
selected_label_slice = label_array[i]
selected_slice_h5 = h5py.File(path+"/data/slices/"+image_name + "_slice_" + str(i) + ".h5", "w")
selected_slice_h5.create_dataset("image", data=selected_image_slice)
selected_slice_h5.create_dataset("label", data=selected_label_slice)
selected_slice_h5.close()