Notes on the nnUNet Source Code

As a deep-learning beginner working my way into medical image segmentation, I recently read through the nnUNet v1 source code and organized its main flow while reading. For layout reasons the blog does not reproduce the source itself; you can look it up in Docs.

Everything below is my own understanding; corrections and comments are very welcome.

1. Dataset format conversion

Command: nnUNet_convert_decathlon_task -i /path/of/input/folder/

2. Data preprocessing

Command: nnUNet_plan_and_preprocess -t 1 --verify_dataset_integrity

2.1 Crop

crop(task_name, False, tf)

(1) Build the dataset list with create_lists_from_splitted_dataset: list_of_files = [[0, 1, 2, 3, label], …]

(2) Copy the labels into cropped/Task00X/gt_segmentations.

(3) Crop each image to its nonzero mask and save the result to

/home/xrd/code/nnUNet/nnUNetFrame/DATASET/nnUNet_raw/nnUNet_cropped_data/Task009_Spleen

2.2 Dataset fingerprint

(1) Based on the modalities listed in the .json file, the presence of CT ("CT"/"ct") decides whether intensity statistics are collected:

collect_intensityproperties = True if (("CT" in modalities) or ("ct" in modalities)) else False

(2) Analyze the dataset

dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False, num_processes=tf) # this class creates the fingerprint

_ = dataset_analyzer.analyze_dataset(collect_intensityproperties)

(3) What is analyzed

analyze_dataset:

dataset_properties['all_sizes'] = sizes
dataset_properties['all_spacings'] = spacings
dataset_properties['all_classes'] = all_classes
dataset_properties['modalities'] = modalities  # {idx: modality name}
dataset_properties['intensityproperties'] = intensityproperties  # computed within the mask: mean, min/max, median, 0.5 and 99.5 percentiles
dataset_properties['size_reductions'] = size_reductions  # {patient_id: size_reduction}

(4) Saved in the folder that holds the cropped images:

save_pickle(dataset_properties,join(self.folder_with_cropped_data,"dataset_properties.pkl"))

(5) Copy dataset_properties.pkl and dataset.json into the folder for preprocessed data.

2.3 Experiment plan for the 3D U-Net: planner_3d = ExperimentPlanner3D_v21

exp_planner = planner_3d(cropped_out_dir, preprocessing_output_dir_this_task)

exp_planner.plan_experiment()

2.3.1 Decide from the cropped volume whether to use the nonzero mask for normalization

def determine_whether_to_use_mask_for_norm

For CT images, return False.

Compute the average fraction of volume that cropping removed across the dataset; if cropping removes more than 25% of the original volume, return True.

(When True, each image is restricted to its nonzero mask before being normalized.)
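A sketch of that decision rule; I am assuming the size_reduction values in the fingerprint are the kept fraction of the original volume per case, and the helper name here is mine:

```python
import numpy as np

# Hedged sketch of determine_whether_to_use_mask_for_norm: CT never uses the
# mask; other modalities use it only if cropping removed more than 25% of the
# volume on average.
def use_mask_for_norm(is_ct, size_reductions):
    if is_ct:
        return False
    mean_kept_fraction = np.mean(list(size_reductions.values()))
    return bool(mean_kept_fraction < 0.75)  # more than 25% cropped away
```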

2.3.2 Determine target_spacing: def get_target_spacing(self)

(1) By default, the median spacing of each axis over the dataset is used.

(2) Check for anisotropy: the dataset is treated as anisotropic if one axis's spacing exceeds 3× the other axes' and its target size is less than 1/3 of theirs:

has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))

has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)

(3)异性时,该轴spacing取10%时的值

target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)

If that 10th-percentile spacing is still smaller than the other axes' spacings, use instead

target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5
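The whole heuristic can be sketched in a few lines. This simplification keeps only the spacing criterion (the voxel-count criterion has_aniso_voxels is omitted), so it is an illustration of the logic above rather than the exact planner code:

```python
import numpy as np

# Simplified sketch of get_target_spacing: median spacing per axis, with the
# 10th-percentile fallback for an anisotropic axis.
def pick_target_spacing(spacings, anisotropy_threshold=3):
    spacings = np.asarray(spacings, dtype=float)   # shape (num_cases, 3)
    target = np.median(spacings, axis=0)
    worst = int(np.argmax(target))
    other = [i for i in range(3) if i != worst]
    if target[worst] > anisotropy_threshold * max(target[other]):
        candidate = np.percentile(spacings[:, worst], 10)
        if candidate < max(target[other]):
            candidate = max(target[other]) + 1e-5
        target[worst] = candidate
    return target
```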

2.3.3 Compute new_shape from the target spacing

new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
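A concrete instance of that formula: halving the spacing along an axis doubles its voxel count (the rounding here is for illustration; the planner keeps float shapes at this point):

```python
import numpy as np

# Resampled shape per axis = original_spacing / target_spacing * original_shape.
spacing = np.array([2.0, 1.0, 1.0])          # coarse spacing along z
size = np.array([100, 200, 200])
target_spacing = np.array([1.0, 1.0, 1.0])
new_shape = np.round(spacing / target_spacing * size).astype(int)
# z goes from 100 voxels at 2.0 mm to 200 voxels at 1.0 mm
```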

2.3.4 Transpose the data by spacing, putting the axis with the largest spacing first

target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]

median_shape_transposed = np.array(median_shape)[self.transpose_forward]

2.3.5 Generate the configuration for 3d_fullres

Determine the properties of each stage:

self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed,
                              median_shape_transposed, len(self.list_of_cropped_npz_files),
                              num_modalities, len(all_classes) + 1)

class ExperimentPlanner3D_v21(ExperimentPlanner):
    def get_properties_for_stage(self, current_spacing, original_spacing, original_shape,
                                 num_cases, num_modalities, num_classes):

(1) Initial value of input_patch_size

From current_spacing, choose an input_patch_size whose physical extent along the axis with the smallest spacing exceeds 512 mm.

Then clamp each axis of input_patch_size against the median shape (new_median_shape), keeping the smaller of the two.

(2) Initial conv kernels, pooling kernels and pooling counts

def get_pool_and_conv_props_poolLateV2(patch_size, min_feature_map_size, max_numpool, spacing)

a. Determine the number of pooling operations per axis

num_pool_per_axis = get_network_numpool(patch_size, max_numpool, min_feature_map_size)

 

def get_network_numpool(patch_size, maxpool_cap=999, min_feature_map_size=4):
    network_numpool_per_axis = np.floor([np.log(i / min_feature_map_size) / np.log(2)
                                         for i in patch_size]).astype(int)
    network_numpool_per_axis = [min(i, maxpool_cap) for i in network_numpool_per_axis]
    return network_numpool_per_axis

b. Compute each axis's conv kernel and pooling kernel at every pooling level

Pooling kernel: pool = [2 if num_pool_per_axis[i] + p >= net_numpool else 1 for i in range(dim)]

If an axis's pooling count num_pool_per_axis plus the current level p reaches net_numpool, the maximum pooling count over all axes, that axis gets pool = 2 at this level; axes with fewer pools start pooling later.

Conv kernel: based on the ratio of each axis's spacing to the current maximum spacing; if the ratio is > 0.5 the kernel size on that axis is 3, otherwise 1.

c. Add a conv kernel for the bottleneck

net_conv_kernel_sizes.append([3] * dim)

d. Compute must_be_divisible_by, the value every patch axis must be divisible by, and pad patch_size accordingly

must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)

patch_size = pad_shape(patch_size, must_be_divisible_by)
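Steps a and d fit in a few lines. This restates get_network_numpool and approximates get_shape_must_be_divisible_by / pad_shape under the assumption of stride-2 pooling on every pooled level (the real helpers live in generic_UNet.py):

```python
import numpy as np

# Step a: an axis can be pooled until its feature map would drop below
# min_feature_map_size voxels.
def get_network_numpool(patch_size, maxpool_cap=999, min_feature_map_size=4):
    numpool = np.floor([np.log(i / min_feature_map_size) / np.log(2)
                        for i in patch_size]).astype(int)
    return [int(min(i, maxpool_cap)) for i in numpool]

# Step d (simplified): k stride-2 poolings require the axis to be divisible by 2**k.
def shape_must_be_divisible_by(num_pool_per_axis):
    return [2 ** k for k in num_pool_per_axis]

def pad_shape(shape, must_be_divisible_by):
    # round each axis up to the next multiple
    return [s if s % m == 0 else s + m - s % m
            for s, m in zip(shape, must_be_divisible_by)]
```

For a 128³ patch, get_network_numpool gives 5 poolings per axis, so each patch edge must be divisible by 32.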

(3) Update input_patch_size, the kernels and the pooling counts under the GPU-memory limit

a. Compare the estimated memory use against the budget; if it is exceeded, shrink the patch size.

b. When shrinking, divide the current patch size by the dataset's median_shape per axis and shrink the axis with the largest ratio first (so the patch's proportions relative to the data change as little as possible):

first compute shape_must_be_divisible_by_new, the updated divisibility constraint after removing one pooling step on that axis, then subtract shape_must_be_divisible_by_new from that axis.

c. Call get_pool_and_conv_props_poolLateV2 again to update the conv kernels, pooling kernels and pooling counts.

d. Re-estimate the memory use, compare again, and repeat until the budget is satisfied.

(4) Compute batch_size

a. The default is 2; if the memory estimate leaves headroom, the batch size is enlarged proportionally:

batch_size = int(np.floor(max(ref / here, 1) * batch_size))

b. Check that batch_size is not too large: one batch should cover no more than 5% of the dataset's voxels; if it would, the batch size is capped (with a minimum of 1).
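Step a's update in isolation (a sketch; ref is the memory budget and here the current estimate, names taken from the snippet above):

```python
import numpy as np

# If there is memory headroom (ref > here) the batch grows proportionally;
# otherwise the default batch size is kept.
def update_batch_size(ref, here, batch_size=2):
    return int(np.floor(max(ref / here, 1) * batch_size))
```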

(5) Decide whether dummy-2D data augmentation is needed

do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[ 0]) > self.anisotropy_threshold

(6) Return value

plan = {
    'batch_size': batch_size,
    'num_pool_per_axis': network_num_pool_per_axis,
    'patch_size': input_patch_size,
    'median_patient_size_in_voxels': new_median_shape,
    'current_spacing': current_spacing,
    'original_spacing': original_spacing,
    'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
    'pool_op_kernel_sizes': pool_op_kernel_sizes,
    'conv_kernel_sizes': conv_kernel_sizes,
}

2.3.6 Decide whether 3d_lowres is needed

If the median image volume in voxels (median_shape) is more than 4× the patch volume (architecture_input_voxels_here),

a downsampled stage is needed:

architecture_input_voxels_here = np.prod(self.plans_per_stage[-1]['patch_size'], dtype=np.int64)
if np.prod(median_shape) / architecture_input_voxels_here < \
        self.how_much_of_a_patient_must_the_network_see_at_stage0:
    more = False
else:
    more = True

2.3.7 Downsampling the dataset (for 3d_lowres)

(1) Loop: each iteration multiplies the spacing by 1.01 and recomputes num_voxels, patch_size and architecture_input_voxels_here, until the condition above is met.

(2) If twice the downsampled images' voxel count, 2 * np.prod(new['median_patient_size_in_voxels']), is still smaller than the original voxel count, 3d_lowres is added to the training plan.

(3) Reorder the plan stages (3d_lowres, if present, becomes stage 0) and convert the list to a dict.

2.3.8 Normalization scheme: determine_normalization_scheme

Based on each modality in the dataset, a scheme is chosen from three classes: CT, noNorm and nonCT.

def determine_normalization_scheme(self):
    schemes = OrderedDict()
    modalities = self.dataset_properties['modalities']
    num_modalities = len(list(modalities.keys()))
    for i in range(num_modalities):
        if modalities[i] == "CT" or modalities[i] == 'ct':
            schemes[i] = "CT"
        elif modalities[i] == 'noNorm':
            schemes[i] = "noNorm"
        else:
            schemes[i] = "nonCT"
    return schemes
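What the scheme labels mean downstream, as a hedged sketch of GenericPreprocessor.resample_and_normalize: the property keys follow the fingerprint dict, and details such as per-channel statistics are omitted here.

```python
import numpy as np

def normalize_ct(data, props):
    # "CT": clip to the global 0.5/99.5 intensity percentiles, then z-score
    # with the GLOBAL mean/std collected in the dataset fingerprint
    data = np.clip(data, props['percentile_00_5'], props['percentile_99_5'])
    return (data - props['mean']) / props['sd']

def normalize_nonct(data):
    # "nonCT" (e.g. MRI): per-image z-score
    return (data - data.mean()) / data.std()
```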

2.3.9 Save the plan

It is saved as nnUNetPlansv2.1_plans_3D.pkl in the preprocessing folder:

plans = {
    'num_stages': len(list(self.plans_per_stage.keys())),
    'num_modalities': num_modalities,
    'modalities': modalities,
    'normalization_schemes': normalization_schemes,
    'dataset_properties': self.dataset_properties,
    'list_of_npz_files': self.list_of_cropped_npz_files,
    'original_spacings': spacings,
    'original_sizes': sizes,
    'preprocessed_data_folder': self.preprocessed_output_folder,
    'num_classes': len(all_classes),
    'all_classes': all_classes,
    'base_num_features': self.unet_base_num_features,
    'use_mask_for_norm': use_nonzero_mask_for_normalization,
    'keep_only_largest_region': only_keep_largest_connected_component,
    'min_region_size_per_class': min_region_size_per_class,
    'min_size_per_class': min_size_per_class,
    'transpose_forward': self.transpose_forward,
    'transpose_backward': self.transpose_backward,
    'data_identifier': self.data_identifier,
    'plans_per_stage': self.plans_per_stage,
    'preprocessor_name': self.preprocessor_name,
    'conv_per_stage': self.conv_per_stage,
}

2.4 Experiment plan for the 2D U-Net

2.5 Run the preprocessing according to the plan (3D U-Net)

exp_planner.run_preprocessing(threads), with threads = (tl, tf)

Whether to actually run preprocessing is configurable.

2.5.1 Load the relevant parameters from the plan

2.5.2 Look up the preprocessor class by preprocessor_name and instantiate it

self.preprocessor_name = "GenericPreprocessor"

preprocessor = preprocessor_class(normalization_schemes, use_nonzero_mask_for_normalization,
                                  self.transpose_forward, intensityproperties)

class GenericPreprocessor(object):

Path: /home/xrd/code/nnUNet/nnunet/preprocessing/preprocessing.py

Threshold above which the low-resolution axis is resampled separately (RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD): 3

self.resample_separate_z_anisotropy_threshold = RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD

self.resample_order_data = 3

self.resample_order_seg = 1

2.5.3 Set the target spacings and the number of worker processes num_threads

2.5.4 Preprocessing (looping over the stages)

preprocessor.run(target_spacings, self.folder_with_cropped_data, self.preprocessed_output_folder,
                 self.plans['data_identifier'], num_threads)

Loop over the stages:

(1) Output path:

output_folder_stage = os.path.join(output_folder, data_identifier + "_stage%d" % i)

(2) Target spacing for the stage
(3) Process the cropped data with a worker pool: p.starmap(self._run_internal, all_args)

def _run_internal

a. Load one cropped case and transpose it according to self.transpose_forward.

b. Resample and normalize:

data, seg, properties = self.resample_and_normalize(data, target_spacing,
                                                    properties, seg, force_separate_z)

1) Resample, printing the spacing and shape before and after:

data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing,
                             self.resample_order_data, self.resample_order_seg,
                             force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
                             separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)

self.resample_order_data = 3  # cubic interpolation

self.resample_order_seg = 1  # linear interpolation

(order 0 is nearest-neighbor)

This determines the resampling mode and resamples data and seg separately. In this call, data uses cubic and seg linear interpolation, and the z axis is not forcibly separated; however, if max_spacing > separate_z_anisotropy_threshold * min_spacing for a case, the z axis is resampled with nearest-neighbor interpolation.

It decides whether to resample the z axis separately, then resamples data and seg into data_reshaped and seg_reshaped; the outputs remain 4-dimensional.

seg_reshaped = resample_data_or_seg

data_reshaped = resample_data_or_seg

def resample_data_or_seg(data, new_shape, is_seg, axis=None, order=3, do_separate_z=False, order_z=0)

2) Normalize

Choose the normalization method according to the modality, and decide whether to restrict it to the nonzero mask.

c. Sample foreground locations

For every class in the segmentation, randomly sample a set of voxel locations (at least 1% of that class's voxels) and store them in the case's properties:

properties['class_locations'] = class_locs, used later when assembling training batches.

d. Save the outputs

The preprocessed image is saved to the stage's output folder:

np.savez_compressed(os.path.join(output_folder_stage, "%s.npz" % case_identifier),

data=all_data.astype(np.float32))

The processed case's properties:

with open(os.path.join(output_folder_stage, "%s.pkl" % case_identifier), 'wb') as f:

pickle.dump(properties, f)

Code: def _run_internal

def _run_internal(self, target_spacing, case_identifier, output_folder_stage, cropped_output_dir, force_separate_z, all_classes): 

Resampling code: def resample_data_or_seg

def resample_data_or_seg(data, new_shape, is_seg, axis=None, order=3, do_separate_z=False, order_z=0):

(4) Close the worker pool.

3. Model training

3.1 Overall flow

3.1.1 Command

nnUNet_train CONFIGURATION TRAINER_CLASS_NAME TASK_NAME_OR_ID FOLD --npz (additional options)

According to the official docs, the trainer (network_trainer) for the 3d_fullres, 3d_lowres and 2d networks is nnUNetTrainerV2; the cascade 3d_cascade_fullres uses nnUNetTrainerV2CascadeFullRes (which requires the 3d_lowres training to be finished first). A custom trainer can also be specified.

parser.add_argument("network")
parser.add_argument("network_trainer")
parser.add_argument("task", help="can be task name or task id")
parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')

3.1.2 Configuration

plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
    trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

Code: def get_default_configuration

def get_default_configuration(network, task, network_trainer, plans_identifier=default_plans_identifier,
                              search_in=(nnunet.__path__[0], "training", "network_training"),
                              base_module='nnunet.training.network_training'):

3.1.3 Training flow (using nnUNetTrainerV2 as the example)

self.max_num_epochs = 1000

self.initial_lr = 1e-2

(1) Check that trainer_class is a subclass of the expected trainer class.

(2) Instantiate it: trainer = trainer_class(...)

(3) Initialize the trainer according to whether we are only validating: trainer.initialize(not validation_only)

(4) Depending on validation-only mode and whether a pretrained model is given, set up the trainer and run training: trainer.run_training()

(5) Optionally run inference on the validation set (enabled by default): trainer.validate.

(6) If the network is 3d_lowres, optionally predict the inputs for the next stage: predict_next_stage.

Initial training parameters: class NetworkTrainer(object)

NetworkTrainer is the parent class of nnUNetTrainerV2.

class NetworkTrainer(object):

def __init__(self, deterministic=True, fp16=False): 

3.2 Key functions (using nnUNetTrainerV2 as the example)

3.2.1 trainer.initialize(not validation_only)

(1) Create the output folder.
(2) Load the plans file into self.plans and copy its fields onto self.
(3) Configure data augmentation: self.setup_DA_params()

def setup_DA_params

- we increase rotation angle from [-15, 15] to [-30, 30]

- scale range is now (0.7, 1.4), was (0.85, 1.25)

- we don't do elastic deformation anymore

Whether dummy-2D augmentation is used depends on whether the ratio of the image's largest to smallest axis exceeds 3.

(4) Deep-supervision loss

The loss weights are set according to the pooling level each output sits at:

class MultipleOutputLoss2(nn.Module):
self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})

class DC_and_CE_loss(nn.Module):
    def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum", square_dice=False,
                 weight_ce=1, weight_dice=1, log_dice=False, ignore_label=None):

Cross-entropy loss + Dice loss

Cross-entropy: CrossEntropyLoss()

Dice: all classes weighted equally

class RobustCrossEntropyLoss(nn.CrossEntropyLoss):

class SoftDiceLoss(nn.Module): 
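The deep-supervision weighting described in (4) can be sketched as follows. This matches my reading of nnUNetTrainerV2: each deeper, lower-resolution output gets half the weight of the one above it, normalized to sum to 1 (in the real trainer the lowest resolutions can additionally be masked out):

```python
import numpy as np

# Weight per decoder output, highest resolution first: 1, 1/2, 1/4, ...,
# normalized so that the weights sum to 1.
def deep_supervision_weights(num_outputs):
    weights = np.array([1 / 2 ** i for i in range(num_outputs)])
    return weights / weights.sum()
```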

(5) Build the basic training and validation data generators (Dataloader)

self.dl_tr, self.dl_val = self.get_basic_generators()

a. self.load_dataset() loads the data.

b. Split the dataset for 5-fold cross-validation: self.do_split()

do_split stores all five folds' splits in the dataset folder and returns the data for the configured fold:

splits_file = join(self.dataset_directory, "splits_final.pkl")

Outputs: self.dataset_tr, self.dataset_val

(When training the high-resolution stage of the cascade (stage 1), the masks produced by the low-resolution stage, stored in the pred_next_stage folder, are added to the dataset.)

c. Instantiate a DataLoader3D:

dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
                     False, oversample_foreground_percent=self.oversample_foreground_percent,
                     pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')

self.basic_generator_patch_size: patch size before data augmentation

self.patch_size: patch size fed into the network

Fraction of each batch forced to contain foreground:

oversample_foreground_percent = 0.33: in each batch, at least a 0.33 fraction of the patches must contain foreground. The foreground locations are sampled at preprocessing time (random seed 1234) and stored in properties; when a batch is drawn, one of those locations is picked at random.

Code:

class DataLoader3D(SlimDataLoaderBase):
    def __init__(self, data, patch_size, final_patch_size, batch_size, has_prev_stage=False,
                 oversample_foreground_percent=0.0, memmap_mode="r", pad_mode="edge",
                 pad_kwargs_data=None, pad_sides=None):
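The 0.33 rule can be made concrete. This sketch reflects my reading of DataLoader3D's oversampling logic (the helper name and exact rounding are assumptions): the last ~33% of sample slots in each batch are the ones forced to contain foreground.

```python
# Sketch of the oversampling rule: sample slots at the end of the batch are
# forced to contain a foreground patch; earlier slots are sampled freely.
def force_foreground(sample_idx, batch_size, oversample_foreground_percent=0.33):
    return not sample_idx < round(batch_size * (1 - oversample_foreground_percent))
```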

(6) Wrap the generators with data augmentation

Augmentation is implemented mainly with the batchgenerators package.

self.tr_gen, self.val_gen = get_moreDA_augmentation(
    self.dl_tr, self.dl_val,
    self.data_aug_params['patch_size_for_spatialtransform'],
    self.data_aug_params,
    deep_supervision_scales=self.deep_supervision_scales,
    pin_memory=self.pin_memory,
    use_nondetMultiThreadedAugmenter=False
)

"data_aug_params":

"{'selected_data_channels': None,
'selected_seg_channels': [0],
'do_elastic': False,  # elastic deformation disabled
'elastic_deform_alpha': (0.0, 900.0),  # magnitude of the elastic deformation
'elastic_deform_sigma': (9.0, 13.0),  # scale of the elastic deformation
'p_eldef': 0.2,  # probability of elastic deformation
'do_scaling': True,  # scaling enabled
'scale_range': (0.7, 1.4),  # scaling range
'independent_scale_factor_for_each_axis': False,  # axes are not scaled independently
'p_independent_scale_per_axis': 1,  # probability of scaling an axis independently
'p_scale': 0.2,  # probability of scaling
'do_rotation': True,  # rotation enabled
'rotation_x': (-0.5235987755982988, 0.5235987755982988),  # rotation angle (about ±30°), drawn uniformly from this range
'rotation_y': (-0.5235987755982988, 0.5235987755982988),
'rotation_z': (-0.5235987755982988, 0.5235987755982988),
'rotation_p_per_axis': 1,  # probability of rotating each axis
'p_rot': 0.2,  # probability of rotation
'random_crop': False,  # no random crop; center crop is used
'random_crop_dist_to_border': None,
'do_gamma': True,  # gamma correction on the non-inverted image
'gamma_retain_stats': True,  # after the gamma transform the data is rescaled to its pre-transform mean and std
'gamma_range': (0.7, 1.5),  # gamma range
'p_gamma': 0.3,  # probability of gamma correction
'do_mirror': True,  # mirroring enabled: applied with probability 1, but each axis is mirrored with probability 0.5
'mirror_axes': (0, 1, 2),  # axes to mirror
'dummy_2D': False,  # no dummy-2D augmentation

 

Gaussian noise, Gaussian blur and multiplicative brightness (multiplying the image by a factor) are added:

tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True,
                                           p_per_sample=0.2, p_per_channel=0.5))
tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))

'mask_was_used_for_normalization': OrderedDict([(0, True), (1, True), (2, True), (3, True)]),

All values outside the mask are set to 0.

 

Labels of -1 in seg are replaced with 0: tr_transforms.append(RemoveLabelTransform(-1, 0))

'border_mode_data': 'constant',  # constant border padding, border_cval_data=0, order_data=3
(border_mode_seg="constant", border_cval_seg=-1, order_seg=1)

'all_segmentation_labels': None,

'move_last_seg_chanel_to_data': False,  # the segmentation is not moved into the data channels

'cascade_do_cascade_augmentations': False,  # no cascade-specific augmentations (for fullres)

'cascade_random_binary_transform_p': 0.4,

'cascade_random_binary_transform_p_per_label': 1,

'cascade_random_binary_transform_size': (1, 8),

'cascade_remove_conn_comp_p': 0.2,

'cascade_remove_conn_comp_max_size_percent_threshold': 0.15,

'cascade_remove_conn_comp_fill_with_other_class_p': 0.0,

'do_additive_brightness': False,  # additive brightness (adding a random offset to the image) disabled

'additive_brightness_p_per_sample': 0.15,  # per-sample probability of changing brightness

'additive_brightness_p_per_channel': 0.5,  # per-channel probability of changing brightness

'additive_brightness_mu': 0.0,  # mean of the added brightness

'additive_brightness_sigma': 0.1,  # std of the added brightness

 

Contrast augmentation, simulated low resolution, and gamma correction on the inverted image:

tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                    p_per_channel=0.5, order_downsample=0,
                                                    order_upsample=3, p_per_sample=0.25,
                                                    ignore_axes=ignore_axes))
tr_transforms.append(GammaTransform(params.get("gamma_range"), True, True,
                                    retain_stats=params.get("gamma_retain_stats"),
                                    p_per_sample=0.1))  # inverted gamma

'num_threads': 12,

'num_cached_per_thread': 2,

'patch_size_for_spatialtransform': array([128, 128, 128])}",

def get_moreDA_augmentation(dataloader_train, dataloader_val, patch_size,
                            params=default_3D_augmentation_params, border_val_seg=-1,
                            seeds_train=None, seeds_val=None, order_seg=1, order_data=3,
                            deep_supervision_scales=None, soft_ds=False, classes=None,
                            pin_memory=True, regions=None,
                            use_nondetMultiThreadedAugmenter: bool = False):

(7) Build the U-Net architecture and deploy it on the GPU

def initialize_network(self): 

Generic_UNet

Path: nnUNet/nnunet/network_architecture/generic_UNet.py

class Generic_UNet(SegmentationNetwork): 

(8) Initialize the optimizer

self.weight_decay = 3e-5

self.initial_lr = 1e-2

def initialize_optimizer_and_scheduler(self):
    assert self.network is not None, "self.initialize_network must be called first"
    self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr,
                                     weight_decay=self.weight_decay, momentum=0.99, nesterov=True)
    self.lr_scheduler = None

3.2.2 trainer.run_training()

(1) Update the learning rate

self.maybe_update_lr(self.epoch)

self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)

def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent
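A quick numeric check of the schedule with nnUNet's defaults (max_num_epochs = 1000, initial_lr = 1e-2); the function is restated so the snippet runs standalone:

```python
# Polynomial learning-rate decay: starts at initial_lr and reaches 0 at max_epochs.
def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs) ** exponent

lrs = [poly_lr(e, 1000, 1e-2) for e in (0, 500, 999)]
# the rate starts at 1e-2, roughly halves by mid-training, and approaches 0
```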

(2) Deep supervision is forcibly enabled.
(3) The parent class carries out the training:

ret = super().run_training()

nnUNetTrainer: saves debug information via self.save_debug_information() into the model output folder (save_json(dct, join(self.output_folder, "debug.json"))), copies the plans file to the output folder, and then calls NetworkTrainer's run_training.

NetworkTrainer.run_training: trains and validates the network; the loop exits when the epoch count exceeds 1000 or a stopping criterion is met. It prints the training and validation losses and the per-class Dice, computes moving-average losses, plots the loss curves, ... and saves the last and best checkpoints.

def run_training(self): 


3.2.3 trainer.validate

How inference works

Before inference, the network's deep supervision is forcibly switched off.

Inference uses a sliding window with a half-window stride; the window size equals patch_size, and the predictions of the overlapping windows are blended with Gaussian weighting.

Each window is additionally predicted with mirroring as test-time augmentation: for 3D data there are 8 mirror configurations, whose predictions are averaged and then multiplied by the Gaussian weights. With the half-window stride, each voxel's output is thus the average of 64 predictions.
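The Gaussian blending can be illustrated with a small importance map. The real code builds it by smoothing a one-hot patch center with scipy's gaussian_filter (sigma = patch_size / 8); this pure-NumPy separable version is a behavioural sketch, not the exact implementation:

```python
import numpy as np

# Separable Gaussian importance map, max-normalized to 1: windows contribute
# most at their centre and least at their borders, which smooths the seams
# between overlapping sliding-window predictions.
def gaussian_importance_map(patch_size, sigma_scale=1 / 8):
    axes = []
    for s in patch_size:
        x = np.arange(s) - (s - 1) / 2          # coordinates centred on the patch
        sigma = s * sigma_scale
        axes.append(np.exp(-0.5 * (x / sigma) ** 2))
    gmap = axes[0][:, None, None] * axes[1][None, :, None] * axes[2][None, None, :]
    return gmap / gmap.max()
```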

Postprocessing

In the newer version the author added a postprocessing step: for each class, keep only the largest connected component and remove the rest, then check whether this improves the prediction.

If a class's predictions improve, the rule is applied in subsequent inference; the result is stored in postprocessing.json in the output folder.

3.2.4 predict_next_stage

Only executed for 3d_lowres, to prepare for the cascade's subsequent high-resolution training.

During 5-fold training, the validation set's segmentations are processed: the low-resolution results are first resampled to the next stage's resolution and image size, then returned:

predicted_new_shape = resample_data_or_seg(predicted, target_shape, False, order=interpolation_order,
                                           do_separate_z=force_separate_z, order_z=interpolation_order_z)
seg_new_shape = predicted_new_shape.argmax(0)

The segmentations are written to the pred_next_stage folder; after all 5 folds have been trained, low-resolution masks exist for the entire dataset and serve as the preprocessed input for the subsequent high-resolution segmentation.

/mnt/users/code/nnUNet/nnUNetFrame/DATASET/nnUNet_trained_models/nnUNet/3d_lowres/Task009_Spleen/nnUNetTrainerV2__nnUNetPlansv2.1/pred_next_stage

4. Inference and postprocessing

nnUNet_predict = nnunet.inference.predict_simple:main

Depending on the command-line options, different models and parameters are used for inference; the main call is:

predict_from_folder(model_folder_name, input_folder, output_folder, folds, save_npz, num_threads_preprocessing,
                    num_threads_nifti_save, lowres_segmentations, part_id, num_parts, not disable_tta,
                    overwrite_existing=overwrite_existing, mode=mode, overwrite_all_in_gpu=all_in_gpu,
                    mixed_precision=not args.disable_mixed_precision,
                    step_size=step_size, checkpoint_name=args.chk)

  1. Create the output folder.

  2. Copy the plans.pkl file into the output folder and check that it exists.

  3. Load plans.pkl and read the expected number of modalities.

  4. Check the input folder's integrity and return the list of case IDs.

  5. Build the list of output file names and the list of input file paths.

  6. If a folder of low-resolution segmentations is given, check it and build the corresponding file list.

  7. Call the prediction function for the chosen mode (normal, fast, fastest) and return its result. The default is normal.

predict_cases (mode = normal)

Different modes are available for inference; taking predict_cases (the normal mode) as the example:

  1. Load the trainer and the checkpoints (one per fold): trainer, params = load_model_and_checkpoint_files

  2. Preprocess the data: preprocessing = preprocess_multithreaded

  3. Run inference: softmax = trainer.predict_preprocessed_data_return_seg_and_softmax

  4. If more than one checkpoint is used, average the predictions.

  5. Optionally postprocess: load postprocessing.json (generated while validating on the validation set and stored next to the model files) and apply the rules it describes.
