I've recently been working on 3D medical image segmentation and am taking notes as I experiment; corrections are welcome if anything is wrong.
All code below is for a 3D segmentation task built on MONAI.
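For reference, the snippets below assume imports along these lines (module paths per the public MONAI API; later sketches pull a few extra transforms from monai.transforms):

import glob
import os

import torch
from monai.data import (CacheDataset, DataLoader, Dataset, SmartCacheDataset,
                        ThreadDataLoader, decollate_batch)
from monai.inferers import sliding_window_inference
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityRanged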
1. Preprocessing
ROI_size = (128, 128, 128)  # patch size for training crops and sliding-window inference
num_samples = 12  # crops drawn per volume if a random-crop transform is used (see below)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # intensity normalization: window the CT values, then rescale to [0, 1]
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=700,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ...  # see the sketch below for one way to fill this in
    ]
)
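The ... placeholder is where resampling and random augmentations would go. One possible filling (my own choice, not the original code; both transforms come from monai.transforms) resamples to a common spacing and then crops num_samples patches of ROI_size per volume:

Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=ROI_size,
    pos=1,
    neg=1,
    num_samples=num_samples,
    image_key="image",
    image_threshold=0,
),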
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=700,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        ...
    ]
)
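Note that val_transforms keeps only deterministic transforms: validation runs on whole volumes via sliding-window inference (shown later), so no random cropping or augmentation is applied.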
2. Data loading
data_dir = r'/home/xxx/'  # data root directory
train_images = sorted(glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
# hold out the last 12 cases for validation
train_files, val_files = data_dicts[:-12], data_dicts[-12:]
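It's worth a quick check that images and labels were paired correctly after the sort (a small sanity check, not from the original post):

print(len(train_files), len(val_files))
print(val_files[0])  # should show a matching image/label file name pair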
1. Dataset
class monai.data.Dataset(data, transform=None)
Parameters
data: the typical input is a list of dictionaries, e.g.:
[
    {'img': 'image1.nii.gz', 'seg': 'label1.nii.gz', 'extra': 123},
    {'img': 'image2.nii.gz', 'seg': 'label2.nii.gz', 'extra': 456},
    {'img': 'image3.nii.gz', 'seg': 'label3.nii.gz', 'extra': 789},
]
transform: the transform pipeline defined above
Example
train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)
val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)
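A quick way to see what the loader yields (monai.utils.first grabs a single batch; the exact shapes depend on the transforms above):

from monai.utils import first

check_data = first(train_loader)
print(check_data["image"].shape, check_data["label"].shape)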
2. CacheDataset
class monai.data.CacheDataset(data, transform=None, cache_num=9223372036854775807, cache_rate=1.0, num_workers=1, progress=True, copy_cache=True, as_contiguous=True, hash_as_key=False, hash_func=<function pickle_hashing>, runtime_cache=False)
A Dataset with a caching mechanism: before the first epoch, every item is loaded into the cache in the state it has right before the first random transform is applied.
For example, suppose the pipeline applies, in order, LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged, Spacingd, RandRotate90d, RandFlipd, RandShiftIntensityd, ...
The first random transform in that order is RandRotate90d, so each item is cached after LoadImaged, EnsureChannelFirstd, ScaleIntensityRanged, and Spacingd have run; the remaining random transforms (RandRotate90d, RandFlipd, RandShiftIntensityd, ...) and any deterministic transforms that follow them are applied during iteration.
So when using CacheDataset, put the deterministic transforms as early in the pipeline as possible (see the sketch below); compared with a plain Dataset this can save a lot of training time.
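As a concrete illustration (the ordering and parameter values are mine; the transform names are standard MONAI), everything before the first Rand* transform is computed once and cached:

train_transforms = Compose(
    [
        # deterministic: run once at startup, output cached
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=700, b_min=0.0, b_max=1.0, clip=True),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        # random: RandRotate90d is the first random transform, so everything
        # from here on re-runs on the cached tensors at every iteration
        RandRotate90d(keys=["image", "label"], prob=0.1, max_k=3),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
    ]
)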
Parameters:
cache_num – total number of items to cache. Default: sys.maxsize.
cache_rate – fraction of the dataset to cache. Default: 1.0 (cache everything).
The actual number of cached items is min(cache_num, data_length × cache_rate, data_length).
num_workers – number of worker threads used to build the cache at initialization.
progress – whether to display a progress bar.
copy_cache – whether to deep-copy the cached content before applying the random transforms; default True. If the random transforms do not modify the cached content, or each cache item is used only once in a multiprocessing environment, setting copy_cache=False can improve performance.
Example
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)
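Note the switch from DataLoader to ThreadDataLoader (with num_workers=0) here: since the transformed data already sits in host memory, iterating it with threads avoids the overhead of spawning loader processes and copying the cached arrays between them. This follows MONAI's fast-training guidance; the regular DataLoader also works.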
3. SmartCacheDataset
class monai.data.SmartCacheDataset(data, transform=None, replace_rate=0.1, cache_num=9223372036854775807, cache_rate=1.0, num_init_workers=1, num_replace_workers=1, progress=True, shuffle=True, seed=0, copy_cache=True, as_contiguous=True, runtime_cache=False)
When the dataset is large, CacheDataset may exhaust memory.
SmartCacheDataset instead keeps a cache of fixed size and, after each training epoch, replaces part of it with items not used in the previous epoch; the number replaced per epoch is cache_num × replace_rate. For example, with 5 images [image1, image2, image3, image4, image5], cache_num=4, and replace_rate=0.25, it swaps 4 × 0.25 = 1 item per epoch, so the images actually cached and trained on in each epoch are:
epoch 1: [image1, image2, image3, image4]
epoch 2: [image2, image3, image4, image5]
epoch 3: [image3, image4, image5, image1]
epoch 4: [image4, image5, image1, image2]
epoch N: [image[N % 5] ...]
Example
train_ds = SmartCacheDataset(
    data=train_files, transform=train_transforms,
    replace_rate=0.7, cache_num=32,  # 32 × 0.7 ≈ 22 items swapped after each epoch
    num_init_workers=4, num_replace_workers=4,
)
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=1, shuffle=True)
The cache must be started before training and refreshed between epochs: call train_ds.start() once before the loop to launch the replacement threads, train_ds.update_cache() at the end of every epoch to swap in the replacement items, and train_ds.shutdown() when training is finished.
train_ds.start()
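The loop below references a model, optimizer, loss, metric, and post-processing transforms that the post does not show. A minimal setup that makes it runnable might look like this (all of it my assumption, using stock MONAI components):

from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import AsDiscrete

max_epochs, val_interval = 300, 2
num_classes = 2  # background + foreground; adjust to your task
model = UNet(
    spatial_dims=3, in_channels=1, out_channels=num_classes,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2,
).to(device)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")
post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)
post_label = AsDiscrete(to_onehot=num_classes)
best_metric, best_metric_epoch = -1.0, -1
epoch_loss_values, metric_values = [], []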
for epoch in range(max_epochs):
    print("-" * 12)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    train_ds.update_cache()  # swap replace_rate of the cache for the next epoch
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = ROI_size
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(data_dir, "xx.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )
train_ds.shutdown()  # stop the replacement threads once training ends