【nnUNetv2：.pth 模型转 ONNX】

import torch
from torch import nn
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.experiment_planning.plan_and_preprocess_api import plan_experiments
import json
def load_json(file: str):
    """Read the JSON file at *file* and return the parsed object.

    The file is always decoded as UTF-8 (the encoding JSON interchange uses)
    instead of the platform's locale default. Otherwise a UTF-8 dataset.json
    with non-ASCII content, such as Chinese label names, fails to load on
    Windows systems whose default encoding is GBK/cp936.
    """
    with open(file, 'r', encoding='utf-8') as f:
        return json.load(f)
from nnunetv2.run.run_training import run_training
import os

def build_network_architecture(plans_manager: PlansManager,
                               dataset_json,
                               configuration_manager: ConfigurationManager,
                               num_input_channels,
                               enable_deep_supervision: bool = True) -> nn.Module:
    """Instantiate the (untrained) network described by the nnU-Net plans.

    Forwards every argument to nnunetv2's ``get_network_from_plans``;
    ``enable_deep_supervision`` is passed on as its ``deep_supervision``
    keyword argument.

    NOTE(review): this positional call matches the older nnUNetv2
    ``get_network_from_plans(plans_manager, dataset_json, configuration_manager,
    num_input_channels, deep_supervision=...)`` signature; newer releases take
    architecture class name / kwargs instead. Confirm against the installed
    version.
    """
    deep_supervision = enable_deep_supervision
    network = get_network_from_plans(
        plans_manager,
        dataset_json,
        configuration_manager,
        num_input_channels,
        deep_supervision=deep_supervision,
    )
    return network

if __name__ == "__main__":
    # ---- user configuration -------------------------------------------------
    dataset_id = 511
    task_name = "roi"  # base name of the exported .onnx file
    checkpoint_name = 'checkpoint_final.pth'
    configuration = '3d_fullres'
    fold = '0'
    # nnUNetv2 names its training output folder <trainer>__<plans>__<configuration>.
    trainer_name = 'nnUNetTrainer'
    plans_identifier = 'nnUNetPlans'
    # Spatial size of the dummy export input. None -> use the patch size stored
    # in the plans for `configuration`. Set e.g. [96, 128, 192] to override.
    patch_size = None

    if nnUNet_results is None:
        raise RuntimeError('The nnUNet_results environment variable is not set.')

    dataset_name = convert_id_to_dataset_name(dataset_id)

    # Standard nnUNetv2 layout:
    #   nnUNet_results/<dataset>/<trainer>__<plans>__<configuration>/fold_<fold>/
    model_training_output_dir = os.path.join(
        nnUNet_results, dataset_name,
        f'{trainer_name}__{plans_identifier}__{configuration}')
    if not os.path.isdir(model_training_output_dir):
        # Fall back to the flat layout nnUNet_results/<dataset>/fold_<fold>/
        # that earlier versions of this script assumed.
        model_training_output_dir = os.path.join(nnUNet_results, dataset_name)

    fold_dir = os.path.join(model_training_output_dir, f'fold_{fold}')
    checkpoint_file = os.path.join(fold_dir, checkpoint_name)

    # nnU-Net checkpoints hold more than tensors (trainer name, init args,
    # logging), which torch >= 2.6 refuses to unpickle with its new default
    # weights_only=True. weights_only=False runs arbitrary pickle code, so only
    # load checkpoints from trusted sources.
    try:
        checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'),
                                weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only argument.
        checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))

    dataset_json = load_json(os.path.join(model_training_output_dir, 'dataset.json'))
    plans = load_json(os.path.join(model_training_output_dir, 'plans.json'))

    plans_manager = PlansManager(plans)
    configuration_manager = plans_manager.get_configuration(configuration)
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager,
                                                      dataset_json)

    # Build the network the way nnU-Net builds it for inference: without deep
    # supervision, so forward() returns a single full-resolution logits tensor
    # rather than a list of multi-scale outputs.
    net = build_network_architecture(plans_manager, dataset_json, configuration_manager,
                                     num_input_channels, enable_deep_supervision=False)
    net.load_state_dict(checkpoint["network_weights"])
    net.eval()

    if patch_size is None:
        patch_size = configuration_manager.patch_size
    # (batch, channels, *spatial); the channel count must match the trained network.
    dummy_input = torch.randn(1, num_input_channels, *patch_size)  # .to("cuda")

    onnx_file = os.path.join(fold_dir, f'{task_name}.onnx')
    with torch.no_grad():
        torch.onnx.export(
            net,
            dummy_input,
            onnx_file,
            input_names=['input'],
            output_names=['output'],
            # dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        )
    print(f'ONNX model written to {onnx_file}')



参考博客：https://blog.csdn.net/qq_35007834/article/details/134707439?spm=1001.2014.3001.5501
本文代码按照上述博客的模板修改而成，可以成功生成 ONNX 模型，但导出过程中会出现一些警告，欢迎在评论区交流！

补充：如果上面的脚本运行不成功，也可以直接修改 predict_from_raw_data.py 文件中的 predict_logits_from_preprocessed_data 方法（见下方代码），取消其中 ONNX 导出部分的注释，然后正常运行一次推理，即可导出 ONNX 模型。

def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
    """
    IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
    TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!

    RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
    SEE convert_predicted_logits_to_segmentation_with_correct_shape

    Copy of nnUNetPredictor.predict_logits_from_preprocessed_data (predict_from_raw_data.py)
    with an optional, commented-out ONNX export added.

    For every parameter set in ``self.list_of_parameters``, the weights are loaded into
    ``self.network`` and a sliding-window prediction is run. The logits are accumulated on
    CPU and averaged when there is more than one parameter set.

    The torch intra-op thread count is capped at ``default_num_processes`` for the duration
    of the call. It is restored afterwards, including when prediction raises.
    """
    n_threads = torch.get_num_threads()
    torch.set_num_threads(min(default_num_processes, n_threads))
    try:
        with torch.no_grad():
            prediction = None

            for params in self.list_of_parameters:

                # messing with state dict names...
                # (a torch.compile'd network keeps the real module in _orig_mod)
                if not isinstance(self.network, OptimizedModule):
                    self.network.load_state_dict(params)
                else:
                    self.network._orig_mod.load_state_dict(params)

                # ---- Optional: export the freshly loaded network to ONNX -------------
                # Uncomment to export. This sits inside the loop, so it runs once per
                # parameter set and overwrites the same file each time.
                # Adjust the dummy input to (1, <input channels>, *<patch size>) of your
                # model, and adjust the output path.
                # dynamic_axes keys must match input_names / output_names exactly
                # (case-sensitive), otherwise they are not applied.
                # print("--------------- exporting to ONNX -------------------------")
                # self.network = self.network.cpu()
                # self.network.eval()
                # x1 = torch.randn(1, 1, 96, 96, 128).cpu()
                # torch.onnx.export(
                #     self.network,
                #     x1,
                #     r'C:/nnUNet/nnUNetFrame/DATASET/nnUNet_results/test/tooth_seg.onnx',
                #     export_params=True,
                #     verbose=True,
                #     do_constant_folding=True,
                #     opset_version=11,
                #     input_names=['modelInput'],
                #     output_names=['modelOutput'],
                #     # dynamic_axes={'modelInput': {0: 'batch', 2: 'depth', 3: 'height', 4: 'width'},
                #     #               'modelOutput': {0: 'batch', 2: 'depth', 3: 'height', 4: 'width'}},
                # )
                # print("Model has been converted to onnx!")

                # why not leave prediction on device if perform_everything_on_device? Because this may cause the
                # second iteration to crash due to OOM. Grabbing that with try except causes way more bloated code
                # than this actually saves computation time
                logits = self.predict_sliding_window_return_logits(data).to('cpu')
                if prediction is None:
                    prediction = logits
                else:
                    prediction += logits

            if len(self.list_of_parameters) > 1:
                prediction /= len(self.list_of_parameters)

            if self.verbose:
                print('Prediction done')
            prediction = prediction.to('cpu')
    finally:
        # restore the caller's thread count even if prediction fails
        torch.set_num_threads(n_threads)
    return prediction
  • 5
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值