nnunetv2 .pth 转 .onnx (Tag=2.0/2.1)

 nnunet v2版本的模型转化为onnx

转化为onnx后可以转化为engine文件,方便在c++使用;(可以移步tensorRT分类中看)

import torch
from nnunetv2.inference.predict_from_raw_data import load_what_we_need
from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
from nnunetv2.paths import nnUNet_results
from nnunetv2.experiment_planning.plan_and_preprocess_api import plan_experiments
import json
def load_json(file: str):
    with open(file, 'r') as f:
        a = json.load(f)
    return a
from preprocess.run_training import run_training
import os

if __name__ == "__main__":

    dataset_id =9
    task_name = "vessel"
    checkpoint_name = 'state_dict.pth'
    configuration = '3d_fullres'
    fold = '1'
    patch_size=[90,160,160]

    dataset_name = convert_id_to_dataset_name(dataset_id)
    model_training_output_dir = os.path.join(nnUNet_results, dataset_name)

    checkpoint = torch.load(os.path.join(nnUNet_results, dataset_name,f'fold_{fold}', checkpoint_name),map_location=torch.device('cpu'))

    dataset_json_file = os.path.join(model_training_output_dir, 'dataset.json')
    dataset_fingerprint_file = os.path.join(model_training_output_dir, 'dataset_fingerprint.json')
    plan_file = os.path.join(model_training_output_dir, 'plan.json')
     
    dataset_json = load_json(dataset_json_file)
    dataset_fingerprint = load_json(dataset_fingerprint_file)
    plan = plan_experiments(dataset_id, dataset_json, dataset_fingerprint, gpu_memory_target_in_gb=8,overwrite_plans_name='nnUNetPlans')
    
    parameters, configuration_manager, inference_allowed_mirroring_axes, \
        plans_manager, network, trainer_name = \
        load_what_we_need(model_training_output_dir, dataset_id, configuration, fold, checkpoint_name, dataset_json, plan)
                         
    f = int(fold) if fold != 'all' else fold

    nnunet_trainer = run_training(dataset_id, configuration, f, plan, dataset_json)

    if not nnunet_trainer.was_initialized:
        nnunet_trainer.initialize()
    net=nnunet_trainer.network
    net.load_state_dict(checkpoint["models_state_dict"][0])
    net.eval()

    dummy_input = torch.randn(1, 1, *patch_size)#.to("cuda")

    torch.onnx.export(
        net,
        dummy_input,
        os.path.join(nnUNet_results, dataset_name,f'fold_{fold}', f'{task_name}.onnx'),
        input_names=['input'],
        output_names=['output'],
        dynamic_axes = {'input': {0: 'batch_size'},'output': {0: 'batch_size'}}
        )

  • 6
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 16
    评论
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值