import multiprocessing
import os
import socket
from typing import Union, Optional
import nnunetv2
import torch.cuda
import torch.distributed as dist
import torch.multiprocessing as mp
from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json
from nnunetv2.paths import nnUNet_preprocessed
from nnunetv2.run.load_pretrained_weights import load_pretrained_weights
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from torch.backends import cudnn
def find_free_network_port() -> int:
"""Finds a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
"""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0 lets the OS pick a free ephemeral port
        return s.getsockname()[1]
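# Example usage (sketch, single-node DDP): reserve a rendezvous port before spawning workers.
#   os.environ.setdefault('MASTER_PORT', str(find_free_network_port()))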
def get_trainer_from_args(dataset_name_or_id: Union[int, str],
configuration: str,
fold: int,
trainer_name: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans',
device: torch.device = torch.device('cuda')):
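    """Resolve the requested trainer class by name, sanity-check it and instantiate it with the dataset's plans
    and dataset.json. The trainer class must live under nnunetv2/training/nnUNetTrainer."""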
# load nnunet class and do sanity checks
nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if nnunet_trainer is None:
raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
f'nnunetv2.training.nnUNetTrainer ('
f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere '
f'else, please move it there.')
assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \
'nnUNetTrainer'
    # handle dataset input. If it's an ID we need to convert it from string to int.
    # str() guards against integer IDs being passed directly, as the type hint allows
    if not str(dataset_name_or_id).startswith('Dataset'):
        try:
            dataset_name_or_id = int(dataset_name_or_id)
        except ValueError:
            raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern '
                             f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your '
                             f'input: {dataset_name_or_id}')
# initialize nnunet trainer
preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))
plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json')
plans = load_json(plans_file)
dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))
nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,
dataset_json=dataset_json, device=device)
return nnunet_trainer
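# Example usage (sketch; assumes dataset 4 has already been preprocessed into nnUNet_preprocessed):
#   trainer = get_trainer_from_args(4, '3d_fullres', fold=0)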
def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool,
                          pretrained_weights_file: Optional[str] = None):
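    """Pick the checkpoint to load, if any: final/latest/best when continuing (--c), final when only running the
    validation (--val); otherwise optionally initialize the network from pretrained weights."""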
if continue_training and pretrained_weights_file is not None:
raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only '
'be used at the beginning of the training.')
if continue_training:
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
if not isfile(expected_checkpoint_file):
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')
# special case where --c is used to run a previously aborted validation
if not isfile(expected_checkpoint_file):
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')
if not isfile(expected_checkpoint_file):
print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
f"continue from. Starting a new training...")
expected_checkpoint_file = None
elif validation_only:
expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
if not isfile(expected_checkpoint_file):
raise RuntimeError(f"Cannot run validation because the training is not finished yet!")
else:
if pretrained_weights_file is not None:
if not nnunet_trainer.was_initialized:
nnunet_trainer.initialize()
load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True)
expected_checkpoint_file = None
if expected_checkpoint_file is not None:
nnunet_trainer.load_checkpoint(expected_checkpoint_file)
def setup_ddp(rank, world_size):
# initialize the process group
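    # uses the default env:// rendezvous, so MASTER_ADDR and MASTER_PORT must already be set (see run_training)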
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup_ddp():
dist.destroy_process_group()
def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, disable_checkpointing, c, val,
pretrained_weights, npz, val_with_best, world_size):
setup_ddp(rank, world_size)
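    # pin this process to the GPU matching its rank (assumes one process per GPU on a single node)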
torch.cuda.set_device(torch.device('cuda', dist.get_rank()))
nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p)
if disable_checkpointing:
nnunet_trainer.disable_checkpointing = disable_checkpointing
    assert not (c and val), 'Cannot set --c and --val flag at the same time.'
maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights)
if torch.cuda.is_available():
cudnn.deterministic = False
cudnn.benchmark = True
if not val:
nnunet_trainer.run_training()
if val_with_best:
nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
nnunet_trainer.perform_actual_validation(npz)
cleanup_ddp()
def run_training(dataset_name_or_id: Union[str, int],
configuration: str, fold: Union[int, str],
trainer_class_name: str = 'nnUNetTrainer',
plans_identifier: str = 'nnUNetPlans',
pretrained_weights: Optional[str] = None,
num_gpus: int = 1,
export_validation_probabilities: bool = False,
continue_training: bool = False,
only_run_validation: bool = False,
disable_checkpointing: bool = False,
val_with_best: bool = False,
device: torch.device = torch.device('cuda')):
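    """Run a single nnU-Net training. Spawns one DDP worker per GPU when num_gpus > 1; otherwise trains in the
    current process on `device`."""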
if plans_identifier == 'nnUNetPlans':
print("\n############################\n"
"INFO: You are using the old nnU-Net default plans. We have updated our recommendations. "
"Please consider using those instead! "
"Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md"
"\n############################\n")
if isinstance(fold, str):
if fold != 'all':
try:
fold = int(fold)
except ValueError as e:
                print(f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!')
raise e
if val_with_best:
assert not disable_checkpointing, '--val_best is not compatible with --disable_checkpointing'
if num_gpus > 1:
assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}"
os.environ['MASTER_ADDR'] = 'localhost'
        if 'MASTER_PORT' not in os.environ:
            port = str(find_free_network_port())
            print(f"using port {port}")
            os.environ['MASTER_PORT'] = port
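        # mp.spawn prepends each worker's rank as the first positional argument of run_ddp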
mp.spawn(run_ddp,
args=(
dataset_name_or_id,
configuration,
fold,
trainer_class_name,
plans_identifier,
disable_checkpointing,
continue_training,
only_run_validation,
pretrained_weights,
export_validation_probabilities,
val_with_best,
num_gpus),
nprocs=num_gpus,
join=True)
else:
nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
plans_identifier, device=device)
if disable_checkpointing:
nnunet_trainer.disable_checkpointing = disable_checkpointing
        assert not (continue_training and only_run_validation), 'Cannot set --c and --val flag at the same time.'
maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
if torch.cuda.is_available():
cudnn.deterministic = False
cudnn.benchmark = True
if not only_run_validation:
nnunet_trainer.run_training()
if val_with_best:
nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, 'checkpoint_best.pth'))
nnunet_trainer.perform_actual_validation(export_validation_probabilities)
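# Example usage (sketch; assumes the dataset was preprocessed, e.g. with nnUNetv2_plan_and_preprocess):
#   run_training('4', '3d_fullres', fold=0, device=torch.device('cuda'))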
def run_training_entry():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('dataset_name_or_id', type=str,
help="Dataset name or ID to train with")
parser.add_argument('configuration', type=str,
help="Configuration that should be trained")
parser.add_argument('fold', type=str,
                        help="Fold of the 5-fold cross-validation. Should be an int between 0 and 4, or 'all' to "
                             "train on all training data at once.")
parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer')
parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans')
parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only '
'be used when actually training. Beta. Use with caution.')
parser.add_argument('-num_gpus', type=int, default=1, required=False,
help='Specify the number of GPUs to use for training')
parser.add_argument('--npz', action='store_true', required=False,
help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted '
'segmentations). Needed for finding the best ensemble.')
parser.add_argument('--c', action='store_true', required=False,
help='[OPTIONAL] Continue training from latest checkpoint')
parser.add_argument('--val', action='store_true', required=False,
help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.')
parser.add_argument('--val_best', action='store_true', required=False,
help='[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead '
'of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! '
'WARNING: This will use the same \'validation\' folder as the regular validation '
'with no way of distinguishing the two!')
parser.add_argument('--disable_checkpointing', action='store_true', required=False,
                        help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out when '
                             'you don\'t want to flood your hard drive with checkpoints.')
parser.add_argument('-device', type=str, default='cuda', required=False,
help="Use this to set the device the training should run with. Available options are 'cuda' "
"(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
"Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
args = parser.parse_args()
assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
if args.device == 'cpu':
# let's allow torch to use hella threads
torch.set_num_threads(multiprocessing.cpu_count())
device = torch.device('cpu')
elif args.device == 'cuda':
# multithreading in torch doesn't help nnU-Net if run on GPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
device = torch.device('cuda')
else:
device = torch.device('mps')
run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
args.num_gpus, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best,
device=device)
if __name__ == '__main__':
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
# reduces the number of threads used for compiling. More threads don't help and can cause problems
    # os.environ values must be strings; assigning an int would raise a TypeError
    os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# multiprocessing.set_start_method("spawn")
    run_training_entry()