将 nnU-Net v2 版本的模型转换为 ONNX
转换为 ONNX 后,还可以进一步转换为 TensorRT 的 engine 文件,方便在 C++ 中使用(具体步骤可以移步 TensorRT 分类查看)。
import torch
from nnunetv2.inference.predict_from_raw_data import load_what_we_need
from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
from nnunetv2.paths import nnUNet_results
from nnunetv2.experiment_planning.plan_and_preprocess_api import plan_experiments
import json
def load_json(file: str):
    """Read a JSON file and return the decoded Python object.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of the locale's preferred encoding,
    # so files containing non-ASCII text load the same way on every platform.
    with open(file, 'r', encoding='utf-8') as f:
        return json.load(f)
from preprocess.run_training import run_training
import os
if __name__ == "__main__":
    # ---- Export settings (hard-coded for this run) -------------------------
    dataset_id = 9
    task_name = "vessel"                  # used as the .onnx file stem
    checkpoint_name = 'state_dict.pth'
    configuration = '3d_fullres'
    fold = '1'
    patch_size = [90, 160, 160]           # spatial size of the dummy input

    # ---- Resolve the result directories ------------------------------------
    dataset_name = convert_id_to_dataset_name(dataset_id)
    model_training_output_dir = os.path.join(nnUNet_results, dataset_name)
    fold_dir = os.path.join(model_training_output_dir, f'fold_{fold}')

    # The checkpoint is mapped onto the CPU while loading.
    checkpoint = torch.load(
        os.path.join(fold_dir, checkpoint_name),
        map_location=torch.device('cpu'),
    )

    # ---- Dataset description and plans -------------------------------------
    dataset_json_file = os.path.join(model_training_output_dir, 'dataset.json')
    dataset_fingerprint_file = os.path.join(
        model_training_output_dir, 'dataset_fingerprint.json'
    )
    # NOTE(review): plan_file is computed but never read below.
    plan_file = os.path.join(model_training_output_dir, 'plan.json')

    dataset_json = load_json(dataset_json_file)
    dataset_fingerprint = load_json(dataset_fingerprint_file)
    plan = plan_experiments(
        dataset_id,
        dataset_json,
        dataset_fingerprint,
        gpu_memory_target_in_gb=8,
        overwrite_plans_name='nnUNetPlans',
    )

    # NOTE(review): none of the six values returned here is used afterwards;
    # the exported network comes from the trainer built further down.
    (
        parameters,
        configuration_manager,
        inference_allowed_mirroring_axes,
        plans_manager,
        network,
        trainer_name,
    ) = load_what_we_need(
        model_training_output_dir,
        dataset_id,
        configuration,
        fold,
        checkpoint_name,
        dataset_json,
        plan,
    )

    # ---- Build the trainer and restore the weights -------------------------
    # The trainer receives an integer fold index, or the literal 'all'.
    if fold != 'all':
        f = int(fold)
    else:
        f = fold
    nnunet_trainer = run_training(dataset_id, configuration, f, plan, dataset_json)
    if not nnunet_trainer.was_initialized:
        nnunet_trainer.initialize()

    net = nnunet_trainer.network
    net.load_state_dict(checkpoint["models_state_dict"][0])
    net.eval()

    # ---- ONNX export -------------------------------------------------------
    # One sample with one channel, created on the CPU.
    dummy_input = torch.randn(1, 1, *patch_size)
    onnx_path = os.path.join(fold_dir, f'{task_name}.onnx')
    # Only axis 0 is declared dynamic, under the name 'batch_size'.
    batch_axis = {0: 'batch_size'}
    torch.onnx.export(
        net,
        dummy_input,
        onnx_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': dict(batch_axis), 'output': dict(batch_axis)},
    )