Three Dataset classes in MONAI (Dataset / CacheDataset / SmartCacheDataset), with example code

I have been working on 3D medical image segmentation recently and am taking notes as I experiment; corrections are welcome.
The code below is written for a 3D segmentation task.

1. Preprocessing

import glob
import os

import torch
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    ScaleIntensityRanged,
)

ROI_size = (128, 128, 128)
num_samples = 12
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # intensity normalization: clip values to [-57, 700], rescale to [0, 1]
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=700,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ...
    ]
)
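The `...` above stands for the rest of the transform chain, which the post elides. Since ROI_size and num_samples are defined but unused in the snippet, a plausible continuation is a random patch sampler plus light augmentation. The following is a sketch under that assumption; the specific transforms are my choice, not confirmed by the post:

from monai.transforms import RandCropByPosNegLabeld, RandFlipd, RandShiftIntensityd

# hypothetical tail of the train_transforms list (assumed, not from the post):
extra_train_transforms = [
    # sample num_samples patches of size ROI_size per volume, balancing
    # foreground-centered (pos) and background-centered (neg) crops
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=ROI_size,
        pos=1,
        neg=1,
        num_samples=num_samples,
        image_key="image",
        image_threshold=0,
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
]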
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=700,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ...
    ]
)

2. Data loading

data_dir = r'/home/xxx/'  # dataset root directory

train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]

# the last 12 cases are held out for validation
train_files, val_files = data_dicts[:-12], data_dicts[-12:]
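As a quick sanity check (not in the original post), you can print one entry and the split sizes before building any dataset:

print(data_dicts[0])                     # {'image': '.../imagesTr/....nii.gz', 'label': '.../labelsTr/....nii.gz'}
print(len(train_files), len(val_files))  # should be len(data_dicts) - 12 and 12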

1. Dataset

class monai.data.Dataset(data, transform=None)

Parameters

data: the typical input is a list of dictionaries, for example:

[
    {'img': 'image1.nii.gz', 'seg': 'label1.nii.gz', 'extra': 123},
    {'img': 'image2.nii.gz', 'seg': 'label2.nii.gz', 'extra': 456},
    {'img': 'image3.nii.gz', 'seg': 'label3.nii.gz', 'extra': 789},
]

transform: a transform chain such as the one defined above.

Example code

train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)

val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)
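Before committing to a long run, it can help to pull a single batch and check the tensor shapes (a minimal check, not in the original post; the exact dimensions depend on the transforms in the chain):

check_batch = next(iter(train_loader))
print(check_batch["image"].shape, check_batch["label"].shape)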

2. CacheDataset

class monai.data.CacheDataset(data, transform=None, cache_num=9223372036854775807, cache_rate=1.0, num_workers=1, progress=True, copy_cache=True, as_contiguous=True, hash_as_key=False, hash_func=<function pickle_hashing>, runtime_cache=False)

A Dataset with a caching mechanism: before the first epoch, every item is loaded and the transforms up to (but not including) the first random transform are applied, and the intermediate results are stored in the cache.

For example, suppose the transform chain applies, in order, LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged, Spacingd, RandRotate90d, RandFlipd, RandShiftIntensityd, ...

The first random transform in the chain is RandRotate90d, so the data after LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged, and Spacingd is what gets cached; the remaining random transforms (RandRotate90d, RandFlipd, RandShiftIntensityd, ...) and any deterministic transforms that come after them are applied on the fly during iteration.

Therefore, when using CacheDataset, place the deterministic transforms at the beginning of the chain wherever possible, as illustrated below. Compared with a plain Dataset, this can save a large amount of training time.
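A minimal sketch of that ordering, reusing the transform names from the example above (the Spacingd pixdim is a placeholder value, not from the post); the comments mark where the caching boundary falls:

from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged,
    Spacingd, RandRotate90d, RandFlipd, RandShiftIntensityd,
)

cache_friendly_transforms = Compose([
    # deterministic: run once before the first epoch, results cached
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=700, b_min=0.0, b_max=1.0, clip=True),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    # random: run every epoch, starting from the cached intermediate results
    RandRotate90d(keys=["image", "label"], prob=0.1, max_k=3),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
])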

Parameters:

cache_num – the total number of items to cache. Defaults to sys.maxsize.

cache_rate – the fraction of the dataset to cache. Defaults to 1.0 (cache everything).

The actual number of cached items is the minimum of cache_num, data_length x cache_rate, and data_length (see the small check after this list).

num_workers – the number of worker threads used to build the cache at initialization.

progress – whether to display a progress bar.

copy_cache – whether to deep-copy the cached content before applying the random transforms. Defaults to True. If the random transforms do not modify the cached content, or each cache item is used only once in a multiprocessing environment, you can set copy_cache=False for better performance.
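A quick check of that min() rule, with made-up numbers (not from the post):

data_length = 20                  # hypothetical dataset size
cache_num, cache_rate = 32, 0.5
effective = min(cache_num, int(data_length * cache_rate), data_length)
print(effective)                  # -> 10: half the data is cached despite cache_num=32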

Example code

train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
# ThreadDataLoader iterates in a lightweight thread instead of worker processes,
# which suits fully cached, in-memory data (hence num_workers=0)
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)

3. SmartCacheDataset

class monai.data.SmartCacheDataset(data, transform=None, replace_rate=0.1, cache_num=9223372036854775807, cache_rate=1.0, num_init_workers=1, num_replace_workers=1, progress=True, shuffle=True, seed=0, copy_cache=True, as_contiguous=True, runtime_cache=False)

When the dataset is large, CacheDataset may exhaust memory.

SmartCacheDataset keeps only a subset of the data in the cache. When one epoch of training finishes, the smart cache replaces a fixed number of cached items with items that are not yet in the cache. For example, with 5 images [image1, image2, image3, image4, image5] and cache_num=4, replace_rate=0.25, the images actually cached and used for training in each epoch are:

epoch 1: [image1, image2, image3, image4]
epoch 2: [image2, image3, image4, image5]
epoch 3: [image3, image4, image5, image1]
epoch 4: [image4, image5, image1, image2]
epoch N: [image[N % 5] ...]
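A toy reproduction of that schedule (this assumes items are replaced in order, as in the documented example; it only illustrates the rotation and is not SmartCacheDataset's actual implementation):

images = ["image1", "image2", "image3", "image4", "image5"]
cache_num, replace_rate = 4, 0.25
replace_num = int(cache_num * replace_rate)   # 1 item swapped per epoch

for epoch in range(1, 5):
    start = (epoch - 1) * replace_num
    cached = [images[(start + i) % len(images)] for i in range(cache_num)]
    print(f"epoch {epoch}: {cached}")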

Example code

train_ds = SmartCacheDataset(data=train_files, transform=train_transforms,
                             replace_rate=0.7, cache_num=32, num_init_workers=4,
                             num_replace_workers=4)
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)

The cache must be updated before every epoch except the first:

train_ds.update_cache()

# assumed setup (not shown in the post): model, optimizer, loss_function,
# dice_metric, post_pred/post_label, plus the following bookkeeping variables
max_epochs = 100    # placeholder training budget
val_interval = 2    # placeholder: validate every 2 epochs
epoch_loss_values = []
metric_values = []
best_metric = -1
best_metric_epoch = -1

train_ds.start()  # start the thread that prepares replacement items in the background
for epoch in range(max_epochs):
    print("-" * 12)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device))

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss:{loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    train_ds.update_cache()  # swap out replace_rate of the cached items for the next epoch

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),)

                roi_size = ROI_size
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(data_dir, "xx.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )
train_ds.shutdown()  # stop the replacement thread and release the cache
