使用 PyTorch 实现 3D 医学图像的 DataLoader，并进行数据增强

import os
import torch
from torch.utils.data import Dataset
import SimpleITK as sitk
from torch.utils.data import DataLoader
import random
import numpy as np
from scipy import ndimage

"""
1. 重采样 -> spacing统一
2. 窗宽窗位调整 
3. 归一化到[0,1]
4. 随机裁剪, 超过边界时进行调整

transform: 翻转、旋转

"""

class CustomNiiDataset(Dataset):
    """Dataset of paired 3D NIfTI image/label volumes producing augmented patches.

    Per-sample pipeline (see ``__getitem__``):
        1. resample image and label to ``new_spacing``;
        2. clip image intensities to the window ``[hu_min, hu_max]``;
        3. rescale the windowed image to ``[0, 1]``;
        4. take a random ``input_size`` crop (padding first if the volume is
           smaller than the crop along some axis);
        5. in ``"train"`` mode, randomly rotate and/or flip;
        6. add a channel axis to the image and one-hot encode the label.

    Args:
        num_classes: number of label classes used for one-hot encoding.
        images_dir: directory containing the image volumes.
        labels_dir: directory containing the label volumes. Every file name
            found here is expected to exist in ``images_dir`` as well.
        hu_min: lower bound of the intensity window.
        hu_max: upper bound of the intensity window (must exceed ``hu_min``).
        new_spacing: target voxel spacing, in SimpleITK (x, y, z) order.
        input_size: crop size, in array (depth, height, width) order.
        mode: ``"train"`` enables random rotation / flip augmentation.
        **kwargs: forwarded to ``Dataset.__init__``.

    Raises:
        ValueError: if ``hu_max <= hu_min``.
    """

    def __init__(self,
                 num_classes,
                 images_dir,
                 labels_dir,
                 hu_min,
                 hu_max,
                 new_spacing=(1.0, 1.0, 1.0),
                 input_size=(64, 64, 64),
                 mode="train",
                 **kwargs):
        super().__init__(**kwargs)
        if hu_max <= hu_min:
            # normalize() divides by (hu_max - hu_min).
            raise ValueError(f"hu_max ({hu_max}) must be greater than hu_min ({hu_min})")
        self.num_classes = num_classes
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.hu_min = hu_min
        self.hu_max = hu_max
        # Copy into fresh lists so defaults/caller objects are never shared.
        self.new_spacing = [float(s) for s in new_spacing]
        self.input_size = list(input_size)
        self.mode = mode

        # os.listdir order is arbitrary; sort so idx -> file is reproducible.
        self.filenames = sorted(os.listdir(labels_dir))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load, preprocess and (in train mode) augment sample ``idx``.

        Returns:
            ``(img_c, lab_c)`` float32 arrays of shape ``[1, d, h, w]`` and
            ``[num_classes, d, h, w]`` where ``(d, h, w) == input_size``.
        """
        filename = self.filenames[idx]
        image_path = os.path.join(self.images_dir, filename)
        label_path = os.path.join(self.labels_dir, filename)

        image = sitk.ReadImage(image_path)
        label = sitk.ReadImage(label_path)

        # Bring both volumes onto the same grid with the target spacing.
        image, label = self.resample(image, label)
        assert image.GetSize() == label.GetSize(), f"error: image.size != label.size !"

        # Window, then rescale to [0, 1].
        image = self.window_intensity(image, self.hu_min, self.hu_max)
        image_array = self.normalize(image)
        label_array = sitk.GetArrayFromImage(label)

        img, lab = self.random_crop(image_array, label_array)

        if self.mode == "train":
            # Each augmentation is applied independently with probability 1/2.
            if random.randint(0, 1) == 0:
                img, lab = self.random_rotate(img, lab)
            if random.randint(0, 1) == 0:
                img, lab = self.random_flip(img, lab)

        # -> [1, d, h, w]; astype also materialises any negative-stride flip view.
        img_c = img[np.newaxis, ...].astype("float32")
        # -> [num_classes, d, h, w]
        lab_c = self._one_hot(lab)
        return img_c, lab_c

    def _one_hot(self, lab):
        """One-hot encode ``lab`` along a new leading class axis.

        Voxels whose value is not in ``range(num_classes)`` get an all-zero
        vector. Returns a float32 array of shape ``[num_classes, *lab.shape]``.
        """
        classes = np.arange(self.num_classes).reshape((-1,) + (1,) * lab.ndim)
        return (lab[np.newaxis, ...] == classes).astype("float32")

    def resample(self, itk_image, itk_label):
        """Resample image and label to ``self.new_spacing`` on the image's grid.

        The output keeps the image's origin and direction and covers the same
        physical extent, so both outputs have identical size.

        Args:
            itk_image: SimpleITK intensity image.
            itk_label: SimpleITK label image.

        Returns:
            ``(resampled_image, resampled_label)``.
        """
        original_spacing = itk_image.GetSpacing()
        original_size = itk_image.GetSize()

        out_size = [
            max(1, int(round(size * spacing / target)))
            for size, spacing, target in zip(original_size, original_spacing, self.new_spacing)
        ]

        resampler = sitk.ResampleImageFilter()
        resampler.SetOutputSpacing(self.new_spacing)
        resampler.SetSize(out_size)
        resampler.SetOutputDirection(itk_image.GetDirection())
        resampler.SetOutputOrigin(itk_image.GetOrigin())
        resampler.SetTransform(sitk.Transform())

        # Intensities: linear interpolation avoids the blocky artefacts of NN.
        # Voxels outside the input are filled with hu_min, which becomes 0
        # after windowing + normalisation.
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetDefaultPixelValue(self.hu_min)
        new_image = resampler.Execute(itk_image)

        # Labels: nearest neighbour keeps values as valid class ids; fill with
        # background 0.
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
        resampler.SetDefaultPixelValue(0)
        new_label = resampler.Execute(itk_label)

        return new_image, new_label

    def window_intensity(self, itk_image, hu_min, hu_max):
        """Clip intensities to the window ``[hu_min, hu_max]``.

        Input and output ranges are identical, so values inside the window are
        unchanged and values outside are clamped to the nearest bound.

        Args:
            itk_image: SimpleITK image.
            hu_min: lower window bound.
            hu_max: upper window bound.

        Returns:
            The windowed SimpleITK image.
        """
        ww_filter = sitk.IntensityWindowingImageFilter()
        ww_filter.SetWindowMinimum(hu_min)
        ww_filter.SetWindowMaximum(hu_max)
        ww_filter.SetOutputMinimum(hu_min)
        ww_filter.SetOutputMaximum(hu_max)
        return ww_filter.Execute(itk_image)

    def normalize(self, itk_image):
        """Map the window ``[self.hu_min, self.hu_max]`` linearly onto ``[0, 1]``.

        Note: uses the configured window, not the image's own min/max.

        Args:
            itk_image: SimpleITK image (expected to be windowed already).

        Returns:
            numpy float array in (depth, height, width) order.
        """
        image_array = sitk.GetArrayFromImage(itk_image)
        value_range = self.hu_max - self.hu_min
        return (image_array - self.hu_min) * 1.0 / value_range

    def random_crop(self, image_array, label_array):
        """Crop the same random ``input_size`` box from image and label.

        Axes shorter than the crop are first zero-padded symmetrically. For
        the image, 0 corresponds to ``hu_min``; for the label, 0 is background.
        Without padding the crop would be impossible. The crop start is
        uniform over all valid positions.

        Raises:
            ValueError: if the array rank differs from ``len(input_size)``.
        """
        assert image_array.shape == label_array.shape, f"error, image_array.shape != label_array.shape !"

        crop_size = tuple(self.input_size)
        if image_array.ndim != len(crop_size):
            raise ValueError(
                f"expected a {len(crop_size)}-D array, got shape {image_array.shape}")

        pad_width = []
        for size, crop in zip(image_array.shape, crop_size):
            missing = max(0, crop - size)
            pad_width.append((missing // 2, missing - missing // 2))
        if any(before or after for before, after in pad_width):
            image_array = np.pad(image_array, pad_width, mode="constant", constant_values=0)
            label_array = np.pad(label_array, pad_width, mode="constant", constant_values=0)

        starts = [random.randint(0, size - crop)
                  for size, crop in zip(image_array.shape, crop_size)]
        box = tuple(slice(start, start + crop) for start, crop in zip(starts, crop_size))
        return image_array[box], label_array[box]

    def random_rotate(self, img, lab):
        """Rotate image and label by the same random angle in the (h, w) plane.

        The image uses linear interpolation (``order=1``). The label uses
        nearest neighbour (``order=0``) so it only contains existing class
        values. Output shapes equal input shapes (``reshape=False``).

        Returns:
            The rotated ``(img, lab)``.
        """
        rotate_angle = random.randint(0, 359)  # 360 would duplicate 0
        img = ndimage.rotate(img, rotate_angle, axes=(1, 2), reshape=False, mode="nearest", order=1)
        lab = ndimage.rotate(lab, rotate_angle, axes=(1, 2), reshape=False, mode="nearest", order=0)
        return img, lab

    def random_flip(self, img, lab):
        """Independently flip along axis 1 and axis 2, each with probability 1/2.

        The same flips are applied to image and label.

        Returns:
            The flipped ``(img, lab)``, possibly as views with negative strides.
        """
        if random.random() < 0.5:
            img = np.flip(img, axis=1)
            lab = np.flip(lab, axis=1)
        if random.random() < 0.5:
            img = np.flip(img, axis=2)
            lab = np.flip(lab, axis=2)
        return img, lab


if __name__ == '__main__':
    # Smoke test: build the training set and print the shape of every batch.
    dataset_config = dict(
        num_classes=2,
        images_dir="G:/blood_vessel2023/images/train",
        labels_dir="G:/blood_vessel2023/lung/train",
        hu_min=-1000,
        hu_max=600,
        mode="train",
    )
    train_dataset = CustomNiiDataset(**dataset_config)

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    for batch_images, batch_labels in train_loader:
        print(batch_images.shape, batch_labels.shape)
    
  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
加载音频文件并将其转换为图像文件需要经过以下几个步骤: 1. 使用音频处理库(如librosa)加载音频文件并读取其数据。 2. 对音频数据进行预处理,例如进行STFT(短时傅里叶变换)将音频数据转换为频谱图。 3. 将预处理后的数据保存为图像文件。 下面是一个示例代码,演示如何配合torch.utils.data.DataLoader加载音频文件并将其转换为图像张量(示例只完成前两步,并未保存文件;如需保存,可先将频谱图缩放到0-255并转换为uint8,再调用Image.save): ```python import torch import librosa import numpy as np from PIL import Image from torch.utils.data import DataLoader, Dataset class AudioImageDataset(Dataset): def __init__(self, audio_files): self.audio_files = audio_files def __len__(self): return len(self.audio_files) def __getitem__(self, idx): # Load audio file audio, sr = librosa.load(self.audio_files[idx], sr=16000) # Preprocess audio data spec = np.abs(librosa.stft(audio, hop_length=512, n_fft=2048)) # Convert to image img = Image.fromarray(spec) # Return image tensor return torch.from_numpy(np.array(img)).unsqueeze(0) # Test the dataset dataset = AudioImageDataset(['audio1.wav', 'audio2.wav']) dataloader = DataLoader(dataset, batch_size=1) for i, batch in enumerate(dataloader): print(batch.shape) # torch.Size([1, 1, 1025, n_frames]) ``` 在这个示例中,我们使用librosa库来加载音频文件并读取其数据。然后,我们对音频数据进行预处理,使用短时傅里叶变换将音频数据转换为频谱图。最后,我们将频谱图转换为图像并返回图像的PyTorch张量表示。借助torch.utils.data.DataLoader,可以按批次逐步读取该数据集,并进行批处理等后续操作(注意:各音频长度不同时频谱帧数也不同,batch_size大于1时需要自定义collate_fn或统一裁剪/填充长度)。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值