The MONAI framework provides two ways to load data. The official docs recommend CacheDataset, which caches the preprocessed data and is therefore much faster.
First, the two loading approaches side by side:
import glob
import os

from monai.data import CacheDataset, DataLoader, Dataset, list_data_collate

data_root = '/workspace/data/medical/Task09_Spleen'
train_images = sorted(glob.glob(os.path.join(data_root, 'imagesTr', '*.nii.gz')))
train_labels = sorted(glob.glob(os.path.join(data_root, 'labelsTr', '*.nii.gz')))
data_dicts = [{'image': image_name, 'label': label_name}
for image_name, label_name in zip(train_images, train_labels)]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
# Dataset (`transform` / `val_trans` are the Compose pipelines defined elsewhere)
train_ds = Dataset(data=train_files, transform=transform)
val_ds = Dataset(data=val_files, transform=val_trans)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
# CacheDataset
cache_train_ds = CacheDataset(data=train_files, transform=transform, cache_rate=1.0, num_workers=4)
cache_val_ds = CacheDataset(data=val_files, transform=val_trans, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(cache_train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate)
val_loader = DataLoader(cache_val_ds, batch_size=1, num_workers=4)
1. How the regular Dataset works (a partial picture)
Data is loaded at every epoch and preprocessed according to the defined transform. In other words, every epoch reprocesses the data from scratch, so a large share of the total time goes into preprocessing.
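This per-epoch cost can be made concrete with a minimal stand-in class (pure Python, not the real MONAI `Dataset`): the transform runs on every `__getitem__` call, so its cost is paid again in every epoch.

```python
class RegularDataset:
    """Conceptual model of a map-style dataset without caching."""

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # The transform runs here, on every single access.
        return self.transform(self.data[idx])

calls = {"count": 0}

def slow_transform(x):
    calls["count"] += 1  # stands in for LoadNiftid, Spacingd, ...
    return x * 2

ds = RegularDataset([1, 2, 3], slow_transform)
for epoch in range(5):          # 5 epochs over 3 items
    for i in range(len(ds)):
        ds[i]

print(calls["count"])           # 15: every item is re-processed every epoch
```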
2. How CacheDataset works
Unlike the above, CacheDataset preprocesses the data before training starts and caches the result. Its processing is split into two parts according to the transform chain. First, the non-random transforms (`LoadNiftid`, `AddChanneld`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged`, which are applied identically to every sample) run once before training and their results are cached. Second, the randomized transforms (e.g. `RandCropByPosNegLabeld`, the random crop) must still run at every epoch, because their output differs on every access.
The corresponding English description (from the source repository):
Dataset with cache mechanism that can load data and cache deterministic transforms' result during training.
When `transforms` is used in a multi-epoch training pipeline, before the first training epoch, this dataset will cache the results up to ``ScaleIntensityRanged``, as all non-random transforms `LoadNiftid`, `AddChanneld`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged` can be cached. During training, the dataset will load the cached results and run ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform and its outcome is not cached.
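The deterministic-prefix/random-suffix split can be sketched with a small stand-in class (a conceptual model, not MONAI's actual implementation): the deterministic part runs once at construction time, while only the random tail runs on each access.

```python
import random

class CachingDataset:
    """Conceptual model of CacheDataset's two-phase processing."""

    def __init__(self, data, deterministic, randomized):
        self.randomized = randomized
        # Pay the deterministic preprocessing cost once, up front.
        self.cache = [deterministic(x) for x in data]

    def __len__(self):
        return len(self.cache)

    def __getitem__(self, idx):
        # Only the random tail (e.g. RandCropByPosNegLabeld) runs per epoch.
        return self.randomized(self.cache[idx])

det_calls = {"n": 0}
rand_calls = {"n": 0}

def deterministic(x):   # stands in for LoadNiftid .. ScaleIntensityRanged
    det_calls["n"] += 1
    return x * 2

def randomized(x):      # stands in for RandCropByPosNegLabeld
    rand_calls["n"] += 1
    return x + random.random()

ds = CachingDataset([1, 2, 3], deterministic, randomized)
for epoch in range(5):
    for i in range(len(ds)):
        ds[i]

print(det_calls["n"], rand_calls["n"])  # 3 15: deterministic work done once
```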
Given this rule, the official recommendation is to place the non-random transforms as early as possible in the chain and the random ones at the end.
transforms = Compose([  # parameter values below follow the official spleen tutorial
    LoadNiftid(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.0), mode=('bilinear', 'nearest')),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label',
                           spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4),  # the only random transform
    ToTensord(keys=['image', 'label'])  # usually placed last, but not required
])
3. Summary
With the caching dataset, data is preprocessed before training, so the initial loading step is slow, but each training epoch is faster than with the regular Dataset. An official benchmark comparison is shown below:
The gap in that figure is striking, but in my own use the difference was not that large.
Explanation of the cache_rate and num_workers settings
Here we use CacheDataset to accelerate the training and validation process; it's about 10x faster than the regular Dataset.
To achieve the best performance, set cache_rate=1.0 to cache all the data; if memory is not enough, set a lower value.
Users can also set cache_num instead of cache_rate; the minimum of the two settings will be used.
Set num_workers to enable multi-threading during caching.
If you want to try the regular Dataset, just switch to the commented code below.
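Based on the quoted description ("the minimum of the two settings will be used"), the number of items actually cached can be sketched as below; `resolve_cache_count` is a hypothetical helper for illustration, not a MONAI function.

```python
import sys

def resolve_cache_count(n_items, cache_rate=1.0, cache_num=sys.maxsize):
    """Sketch of how cache_rate and cache_num combine: the minimum wins,
    capped by the dataset size."""
    return min(n_items, int(n_items * cache_rate), cache_num)

print(resolve_cache_count(100))                                # 100: cache everything
print(resolve_cache_count(100, cache_rate=0.5))                # 50
print(resolve_cache_count(100, cache_num=30))                  # 30
print(resolve_cache_count(100, cache_rate=0.5, cache_num=80))  # 50: the minimum wins
```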
Example from the official experiments: monai
One remaining question: the official docs say that after switching to CacheDataset, the DataLoader no longer needs many workers, and even num_workers=1 is fine. But when I run the DataLoader without multiple workers, data loading is still very slow. I am not sure whether my random transforms are simply expensive or whether I am using it incorrectly, so I keep using multiple workers anyway.
train_ds = monai.data.CacheDataset(
data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=8
)
# don't need many workers because already cached the data
train_loader = monai.data.DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=1, pin_memory=True)