import os
import torch
from torch.utils.data import Dataset
import SimpleITK as sitk
from torch.utils.data import DataLoader
import random
import numpy as np
from scipy import ndimage
"""
1. 重采样 -> spacing统一
2. 窗宽窗位调整
3. 归一化到[0,1]
4. 随机裁剪, 超过边界时进行调整
transform: 翻转、旋转
"""
class CustomNiiDataset(Dataset):
    """Dataset that reads paired NIfTI image/label volumes for 3D segmentation.

    Per-sample pipeline:
    resample to ``new_spacing`` -> intensity windowing to [hu_min, hu_max]
    -> scale to [0, 1] -> random crop of ``input_size`` (padding when the
    volume is smaller) -> random rotation / flip (train mode only)
    -> add a channel axis to the image and one-hot encode the label.

    Args:
        num_classes: number of label classes used for one-hot encoding.
        images_dir: folder containing the image volumes.
        labels_dir: folder containing the label volumes; every file name in
            this folder must also exist in ``images_dir``.
        hu_min: lower bound of the intensity window.
        hu_max: upper bound of the intensity window (must exceed ``hu_min``).
        new_spacing: target voxel spacing in SimpleITK (x, y, z) order.
        input_size: crop size in array (d, h, w) order.
        mode: "train" enables random rotation/flip; any other value disables it.

    Raises:
        ValueError: if ``hu_max <= hu_min``.
    """

    def __init__(self,
                 num_classes,
                 images_dir,
                 labels_dir,
                 hu_min,
                 hu_max,
                 new_spacing=(1.0, 1.0, 1.0),
                 input_size=(64, 64, 64),
                 mode="train",
                 **kwargs):
        super().__init__(**kwargs)
        if hu_max <= hu_min:
            # normalize() divides by (hu_max - hu_min).
            raise ValueError(f"hu_max ({hu_max}) must be greater than hu_min ({hu_min})")
        self.num_classes = num_classes
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.hu_min = hu_min
        self.hu_max = hu_max
        # Tuples (copies) so a caller's list is never aliased or mutated.
        self.new_spacing = tuple(new_spacing)
        self.input_size = tuple(input_size)
        self.mode = mode
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.filenames = sorted(os.listdir(labels_dir))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load, preprocess and (in train mode) augment sample ``idx``.

        Returns:
            img_c: float32 array of shape [1, d, h, w].
            lab_c: float32 one-hot array of shape [num_classes, d, h, w].
        """
        filename = self.filenames[idx]
        image = sitk.ReadImage(os.path.join(self.images_dir, filename))
        label = sitk.ReadImage(os.path.join(self.labels_dir, filename))
        # Both volumes are resampled onto the image's grid at the target spacing.
        image, label = self.resample(image, label)
        if image.GetSize() != label.GetSize():
            raise ValueError("error: image.size != label.size !")
        image = self.window_intensity(image, self.hu_min, self.hu_max)
        image_array = self.normalize(image)
        label_array = sitk.GetArrayFromImage(label)
        img, lab = self.random_crop(image_array, label_array)
        if self.mode == "train":
            if random.randint(0, 1) == 0:
                img, lab = self.random_rotate(img, lab)
            if random.randint(0, 1) == 0:
                img, lab = self.random_flip(img, lab)
        # astype copies, which also makes flipped (negative-stride) views contiguous.
        img_c = img[np.newaxis, ...].astype("float32")
        lab_c = self._one_hot(lab)
        return img_c, lab_c

    def _one_hot(self, lab):
        """One-hot encode ``lab`` into float32 [num_classes, *lab.shape].

        Voxels whose value is not in range(num_classes) get an all-zero vector.
        """
        classes = np.arange(self.num_classes).reshape((-1,) + (1,) * lab.ndim)
        return (lab[np.newaxis, ...] == classes).astype("float32")

    def resample(self, itk_image, itk_label):
        """Resample image and label onto a grid with spacing ``self.new_spacing``.

        The output grid keeps the image's origin and direction. Its size is chosen
        so the physical extent is (approximately) preserved.

        Args:
            itk_image: SimpleITK image.
            itk_label: SimpleITK label map.
        Returns:
            (resampled_image, resampled_label)
        """
        out_size = [
            max(1, int(round(size * spacing / target)))
            for size, spacing, target in zip(
                itk_image.GetSize(), itk_image.GetSpacing(), self.new_spacing)
        ]
        resampler = sitk.ResampleImageFilter()
        resampler.SetOutputSpacing([float(s) for s in self.new_spacing])
        resampler.SetSize(out_size)
        resampler.SetOutputDirection(itk_image.GetDirection())
        resampler.SetOutputOrigin(itk_image.GetOrigin())
        resampler.SetTransform(sitk.Transform())

        # Image: linear interpolation. Voxels outside the source extent are filled
        # with the image's own minimum (the original passed the pixel-type id here).
        stats = sitk.MinimumMaximumImageFilter()
        stats.Execute(itk_image)
        resampler.SetDefaultPixelValue(stats.GetMinimum())
        resampler.SetInterpolator(sitk.sitkLinear)
        new_image = resampler.Execute(itk_image)

        # Label: nearest neighbour keeps class ids intact; outside voxels become 0.
        resampler.SetDefaultPixelValue(0)
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
        new_label = resampler.Execute(itk_label)
        return new_image, new_label

    def window_intensity(self, itk_image, hu_min, hu_max):
        """Clamp intensities to the window [hu_min, hu_max].

        The output range is set equal to the input window, so values inside the
        window are unchanged and values outside are clamped.

        Args:
            itk_image: SimpleITK image.
            hu_min: window lower bound.
            hu_max: window upper bound.
        Returns:
            The windowed SimpleITK image.
        """
        ww_filter = sitk.IntensityWindowingImageFilter()
        ww_filter.SetWindowMinimum(hu_min)
        ww_filter.SetWindowMaximum(hu_max)
        ww_filter.SetOutputMinimum(hu_min)
        ww_filter.SetOutputMaximum(hu_max)
        return ww_filter.Execute(itk_image)

    def normalize(self, itk_image):
        """Map the configured window [hu_min, hu_max] linearly onto [0, 1].

        Note that this uses ``self.hu_min`` / ``self.hu_max``, not the image's own
        min/max.

        Args:
            itk_image: SimpleITK image (expected to be windowed already).
        Returns:
            float64 numpy array in SimpleITK array order (d, h, w).
        """
        # Convert to float first: integer pixel dtypes can overflow (or, for
        # unsigned types, reject a negative hu_min) in the subtraction.
        image_array = sitk.GetArrayFromImage(itk_image).astype(np.float64)
        value_range = self.hu_max - self.hu_min
        return (image_array - self.hu_min) / value_range

    def random_crop(self, image_array, label_array):
        """Take the same random ``self.input_size`` crop from image and label.

        Any axis shorter than the crop is first zero-padded (split evenly before and
        after), so a crop always fits. Padding with 0 means hu_min for a normalized
        image, and class 0 for the label (presumably background; confirm against
        your labels). Otherwise each axis's start offset is drawn uniformly from
        every position where the crop fits.

        Raises:
            ValueError: if the two arrays have different shapes.
        """
        if image_array.shape != label_array.shape:
            raise ValueError("error, image_array.shape != label_array.shape !")
        crop_size = tuple(int(s) for s in self.input_size)

        pad_width = []
        for dim, size in zip(image_array.shape, crop_size):
            missing = max(size - dim, 0)
            pad_width.append((missing // 2, missing - missing // 2))
        if any(before or after for before, after in pad_width):
            image_array = np.pad(image_array, pad_width, mode="constant", constant_values=0)
            label_array = np.pad(label_array, pad_width, mode="constant", constant_values=0)

        window = tuple(
            slice(start, start + size)
            for start, size in (
                (random.randint(0, dim - size), size)
                for dim, size in zip(image_array.shape, crop_size)
            )
        )
        return image_array[window], label_array[window]

    def random_rotate(self, img, lab):
        """Rotate image and label by the same random integer angle in [0, 360]
        degrees in the plane of axes 1 and 2, keeping the array shape.

        The image uses linear interpolation. The label uses nearest neighbour
        (order=0) so class ids stay integral.
        """
        rotate_angle = random.randint(0, 360)
        img = ndimage.rotate(img, rotate_angle, axes=(1, 2), reshape=False, mode="nearest", order=1)
        lab = ndimage.rotate(lab, rotate_angle, axes=(1, 2), reshape=False, mode="nearest", order=0)
        return img, lab

    def random_flip(self, img, lab):
        """Independently flip image and label along axis 1 and along axis 2,
        each with probability 0.5. Axis 0 is never flipped.

        Returns numpy views (possibly with negative strides).
        """
        for axis in (1, 2):
            if random.random() < 0.5:
                img = np.flip(img, axis=axis)
                lab = np.flip(lab, axis=axis)
        return img, lab
if __name__ == '__main__':
    # Smoke test: build the training dataset and print the shape of every batch.
    dataset_kwargs = dict(
        num_classes=2,
        images_dir="G:/blood_vessel2023/images/train",
        labels_dir="G:/blood_vessel2023/lung/train",
        hu_min=-1000,
        hu_max=600,
        mode="train",
    )
    train_dataset = CustomNiiDataset(**dataset_kwargs)
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    for batch_img, batch_lab in train_loader:
        print(batch_img.shape, batch_lab.shape)
# torch实现3d医学图像的dataloader,并进行数据增强
# (PyTorch implementation of a 3D medical-image DataLoader with data augmentation)
# 最新推荐文章于 2023-05-15 15:15:42 发布