问题
收到的手工标注标签数据(mask.nii.gz)与原始图像数据(.nii.gz)的尺寸(size)不匹配。
思路
在ITK中分别打开这两个数据后发现，mask 上的标注与 image 上的相应位置仍然是对应的，说明标注本身并没有错位；因此推测两者的差异主要在于体素间距(spacing)不同，需要先把它们重采样到同一个网格上。
代码
import SimpleITK as sitk
import os
import numpy as np
def resize_mask(reference_image, mask, save_image_path, save_mask_path):
    """Resample ``mask`` onto the voxel grid of ``reference_image`` and save both.

    The output grid (size, origin, spacing, direction) is taken from
    ``reference_image``. The mask is resampled with nearest-neighbour
    interpolation so label values are never blended, and reference-grid voxels
    that fall outside the mask's extent are filled with 0.

    Resampling ``reference_image`` onto its own grid would be an identity
    operation, so it is written out unchanged regardless of its size.

    Args:
        reference_image: SimpleITK image whose geometry defines the output grid.
        mask: SimpleITK label image to resample.
        save_image_path: Output path for ``reference_image``.
        save_mask_path: Output path for the resampled mask.
    """
    resampled_mask = sitk.Resample(
        mask,
        reference_image.GetSize(),
        sitk.Transform(),  # identity transform: match purely by physical position
        sitk.sitkNearestNeighbor,
        reference_image.GetOrigin(),
        reference_image.GetSpacing(),
        reference_image.GetDirection(),
        0.0,  # value for grid points outside the mask
    )
    sitk.WriteImage(resampled_mask, save_mask_path)
    sitk.WriteImage(reference_image, save_image_path)
def Croped(image, label, crop_size=256):
    """Center-crop / zero-pad ``image`` and ``label`` to a fixed in-plane size.

    Axes are handled in numpy (z, y, x) order:

    * y and x: cropped around their centre when longer than ``crop_size`` and
      zero-padded when shorter, so both end up exactly ``crop_size`` voxels.
    * z: cropped around its centre only when it has more than ``crop_size``
      slices; shorter stacks are kept whole and never padded.

    The same index window (derived from ``image``'s array shape) is applied to
    both arrays. Each output keeps its source's spacing and direction, and its
    origin is moved to the physical position of the window's first voxel, so
    cropping/padding does not shift the volume in world coordinates.

    NOTE(review): assumes ``label`` has the same array shape as ``image``; if it
    does not, its crop is silently truncated — confirm upstream.

    Args:
        image: SimpleITK image that defines the crop window.
        label: SimpleITK label image cropped with the same window.
        crop_size: Target edge length in voxels (default 256).

    Returns:
        Tuple ``(new_image, new_seg)`` of SimpleITK images.
    """
    image_array = sitk.GetArrayFromImage(image)  # numpy axis order is (z, y, x)
    label_array = sitk.GetArrayFromImage(label)
    half = crop_size // 2

    starts = []      # window start per numpy axis (may be negative => padding)
    pad_width = []   # (before, after) zero padding per numpy axis
    window = []      # slice per numpy axis, in padded coordinates
    for axis, length in enumerate(image_array.shape):
        if axis == 0 and length < crop_size:
            # Slice axis with few slices: keep them all, never pad.
            start, stop = 0, length
        else:
            start = length // 2 - half
            stop = start + crop_size
        # Pad wherever the window reaches outside the array; this also avoids
        # Python's negative-index wrap-around on a negative start.
        before = max(0, -start)
        after = max(0, stop - length)
        pad_width.append((before, after))
        window.append(slice(start + before, stop + before))
        starts.append(start)
    window = tuple(window)

    # ascontiguousarray: hand SimpleITK a C-contiguous buffer, not a strided view.
    image_array = np.ascontiguousarray(np.pad(image_array, pad_width, mode='constant')[window])
    label_array = np.ascontiguousarray(np.pad(label_array, pad_width, mode='constant')[window])
    print(image_array.shape)

    # SimpleITK indices are (x, y, z): reverse the numpy-order window starts.
    start_index = [float(s) for s in reversed(starts)]
    new_image = _wrap_like(image_array, image, start_index)
    new_seg = _wrap_like(label_array, label, start_index)
    return new_image, new_seg


def _wrap_like(array, reference, start_index):
    """Wrap ``array`` as a SimpleITK image on ``reference``'s grid, shifted to ``start_index``."""
    out = sitk.GetImageFromArray(array)
    out.SetSpacing(reference.GetSpacing())
    out.SetDirection(reference.GetDirection())
    # The cropped/padded volume begins at reference index ``start_index``.
    out.SetOrigin(reference.TransformContinuousIndexToPhysicalPoint(start_index))
    return out
def Get_Resame(image_path, mask_path, new_spacing=(1.5, 1.5, 5.0)):
    """Read an image/mask pair and resample both onto one common grid.

    The image is resampled with linear interpolation onto a grid with
    ``new_spacing`` that keeps the image's origin and direction and covers
    (approximately) the same physical extent. The mask is then resampled with
    nearest-neighbour interpolation onto exactly that same grid — not onto a
    grid derived from the mask's own geometry — so the two returned volumes
    have identical size, origin, spacing and direction and their arrays
    correspond voxel-for-voxel.

    Args:
        image_path: Path of the image file.
        mask_path: Path of the label file.
        new_spacing: Target spacing in SimpleITK (x, y, z) order, in the same
            units as the input spacing (presumably mm for NIfTI — confirm).
            Defaults to (1.5, 1.5, 5.0).

    Returns:
        Tuple ``(resampled_image, resampled_mask)`` of SimpleITK images.
    """
    image = sitk.ReadImage(image_path)
    mask = sitk.ReadImage(mask_path)

    original_spacing = image.GetSpacing()
    # Preserve the physical extent; round rather than truncate, and never
    # collapse an axis to zero voxels.
    new_size = [
        max(1, round(sz * spc / ns))
        for sz, spc, ns in zip(image.GetSize(), original_spacing, new_spacing)
    ]
    origin = image.GetOrigin()
    direction = image.GetDirection()

    resampled_image = sitk.Resample(image, new_size, sitk.Transform(), sitk.sitkLinear,
                                    origin, new_spacing, direction)
    # Same output grid as the image, so image and mask arrays line up.
    resampled_mask = sitk.Resample(mask, new_size, sitk.Transform(), sitk.sitkNearestNeighbor,
                                   origin, new_spacing, direction)
    return resampled_image, resampled_mask
def preprocess_images_and_masks(images_folder, masks_folder, save_images_folder, save_masks_folder):
    """Resample, crop and align every image/mask pair and write the results.

    File names are taken from ``masks_folder``. For each one, the image with the
    same file name is read from ``images_folder``. The pair goes through
    ``Get_Resame`` -> ``Croped`` -> ``resize_mask``, and the outputs are written
    under the same file name into ``save_images_folder`` / ``save_masks_folder``.
    Both output folders are created if they do not exist.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_images_folder, exist_ok=True)
    os.makedirs(save_masks_folder, exist_ok=True)

    # Iterate the mask folder, so only files that have a mask are processed;
    # sort for a deterministic processing order.
    for file_name in sorted(os.listdir(masks_folder)):
        mask_path = os.path.join(masks_folder, file_name)
        if not os.path.isfile(mask_path):
            continue  # skip sub-directories; they cannot be read as images
        image_path = os.path.join(images_folder, file_name)
        save_image_path = os.path.join(save_images_folder, file_name)
        save_mask_path = os.path.join(save_masks_folder, file_name)

        resampled_image, resampled_mask = Get_Resame(image_path, mask_path)
        cropped_image, cropped_mask = Croped(resampled_image, resampled_mask)
        resize_mask(cropped_image, cropped_mask, save_image_path, save_mask_path)
# Input folders: source images and their hand-annotated masks.
images_folder = "D:/Download/data/data/images"
masks_folder = "D:/Download/masks"
# Output folders for the preprocessed images and masks.
save_images_folder = "D:/Download/data/data/image_pre"
save_masks_folder = "D:/Download/data/data/mask_pre"

# Run the preprocessing pipeline.
preprocess_images_and_masks(
    images_folder=images_folder,
    masks_folder=masks_folder,
    save_images_folder=save_images_folder,
    save_masks_folder=save_masks_folder,
)