nnunetv2 series: converting torch to ONNX

First, thanks to the author of https://blog.csdn.net/chen_niansan/article/details/141940273 for providing the original code. The version below makes some small adjustments and optimizations, since the original code could overwrite its output files. Note that batch_size is best left at 0, so the exported model has a dynamic batch dimension and there is no fixed limit on the number of images at inference time.

Code example


import json
from os.path import isdir, join
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
# import onnx
from onnx import (
    load,
    checker
)
# import onnxruntime
from onnxruntime import InferenceSession
# import torch
from torch import (
    rand,
)
from torch.onnx import export as torch_onnx_export
from batchgenerators.utilities.file_and_folder_operations import load_json

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.utilities.dataset_name_id_conversion import (
    maybe_convert_to_dataset_name,
)
from nnunetv2.utilities.file_path_utilities import get_output_folder

def export_dataset_json(
    json_path,
    c, fold, batch_size, use_dynamic_axes,
    config, dataset_name,
    dataset_json, foreground_intensity_properties,
):
    """Write a sidecar JSON describing the exported model (configuration,
    fold, batch size, patch size, spacing, normalization) and the dataset
    (channels, foreground intensity properties, labels)."""
    with open(json_path, "w", encoding="utf-8") as f:
        config_dict = {
            "configuration": c,
            "fold": fold,
            "model_parameters": {
                "batch_size": batch_size
                if not use_dynamic_axes
                else "dynamic",
                "patch_size": config.patch_size,
                "spacing": config.spacing,
                "normalization_schemes": config.normalization_schemes,
                # These are mostly interesting for certification
                # uses, but they are also useful for debugging.
                # "UNet_class_name": config.UNet_class_name,
                # "UNet_base_num_features": config.UNet_base_num_features,
                # "unet_max_num_features": config.unet_max_num_features,
                # "conv_kernel_sizes": config.conv_kernel_sizes,
                # "pool_op_kernel_sizes": config.pool_op_kernel_sizes,
                # "num_pool_per_axis": config.num_pool_per_axis,
            },
            "dataset_parameters": {
                "dataset_name": dataset_name,
                "num_channels": len(
                    dataset_json["channel_names"].keys()
                ),
                "channels": {
                    k: {
                        "name": v,
                        # For when normalization is not Z-Score
                        "foreground_properties": foreground_intensity_properties[
                            k
                        ],
                    }
                    for k, v in dataset_json[
                        "channel_names"
                    ].items()
                },
                "num_classes": len(dataset_json["labels"].keys()),
                "class_names": {
                    v: k for k, v in dataset_json["labels"].items()
                },
            },
        }

        json.dump(
            config_dict,
            f,
            indent=4,
        )


def export_onnx_model(
    dataset_name_or_id: Union[int, str],
    output_dir: Path,
    configurations: Tuple[str, ...] = ("2d",),
    batch_size: int = 0,
    trainer: str = "nnUNetTrainer",
    plans_identifier: str = "nnUNetPlans",
    # folds: Tuple[Union[int, str], ...] = (0, 1, 2, 3, 4),
    folds: Tuple[Union[int, str], ...] = (0,),
    strict: bool = True,
    save_checkpoints: Tuple[str, ...] = ("checkpoint_final.pth",),
    output_names: Optional[Tuple[str, ...]] = None,
    verbose: bool = False,
) -> None:
    """Export trained nnU-Net v2 models of the given dataset to ONNX.

    batch_size == 0 exports with a dynamic batch dimension; any positive
    value is baked into the exported graph as a fixed batch size.
    """
    if not output_names:
        output_names = tuple(
            f"{checkpoint[:-4]}_fold_{fold_index}.onnx"
            for fold_index, checkpoint in zip(folds, save_checkpoints)
        )

    if batch_size < 0:
        raise ValueError("batch_size must be non-negative")
    # batch_size == 0 means: export with a dynamic batch dimension, so the
    # number of images per inference call is not fixed in the graph.
    use_dynamic_axes = batch_size == 0

    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
    for c in configurations:
        # print(f"Configuration {c}")
        trainer_output_dir = get_output_folder(
            dataset_name, trainer, plans_identifier, c
        )

        # Check the model folder before reading from it, so that a missing
        # configuration is reported (or skipped) instead of crashing on load.
        if not isdir(trainer_output_dir):
            if strict:
                raise RuntimeError(
                    f"{dataset_name} is missing the trained model of configuration {c}"
                )
            else:
                print(f"Skipping configuration {c}, does not exist")
                continue

        dataset_json = load_json(join(trainer_output_dir, "dataset.json"))

        # The plans file provides the per-channel foreground intensity
        # properties, which are written into the exported config JSON.
        plans = load_json(join(trainer_output_dir, "plans.json"))
        foreground_intensity_properties = plans[
            "foreground_intensity_properties_per_channel"
        ]

        predictor = nnUNetPredictor(
            # perform_everything_on_gpu=False,
            # device=torch.device("cpu"),
        )

        for checkpoint_name, output_name in zip(
            save_checkpoints, output_names
        ):
            predictor.initialize_from_trained_model_folder(
                model_training_output_dir=trainer_output_dir,
                use_folds=folds,
                checkpoint_name=checkpoint_name,
                # disable_compilation=True,
            )

            # nnU-Net keeps one state dict per requested fold; the same
            # network object is re-loaded for each fold below.
            list_of_parameters = predictor.list_of_parameters
            network = predictor.network
            config = predictor.configuration_manager

            for fold, params in zip(folds, list_of_parameters):
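                # Load this fold's weights into the shared network and switch
                # to eval mode before tracing the graph for export.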
                network.load_state_dict(params)

                network.eval()

                # curr_output_dir = output_dir / c / f"fold_{fold}"
                curr_output_dir = Path(output_dir)
                curr_output_dir.mkdir(parents=True, exist_ok=True)

                # The number of input channels must match the dataset's
                # channel count.
                num_input_channels = len(dataset_json["channel_names"])

                if use_dynamic_axes:
                    rand_input = rand((1, num_input_channels, *config.patch_size))
                    torch_output = network(rand_input)

                    torch_onnx_export(
                        network,
                        rand_input,
                        curr_output_dir / output_name,
                        export_params=True,
                        verbose=verbose,
                        input_names=["input"],
                        output_names=["output"],
                        dynamic_axes={
                            "input": {0: "batch_size"},
                            "output": {0: "batch_size"},
                        },
                    )
                else:
                    rand_input = rand(
                        (batch_size, num_input_channels, *config.patch_size)
                    )
                    torch_output = network(rand_input)

                    torch_onnx_export(
                        network,
                        rand_input,
                        curr_output_dir / output_name,
                        export_params=True,
                        verbose=verbose,
                        input_names=["input"],
                        output_names=["output"],
                    )

                # Round-trip check: validate the exported graph, then run it
                # with ONNX Runtime on CPU using the same random input.
                onnx_model = load(str(curr_output_dir / output_name))
                checker.check_model(onnx_model)

                ort_session = InferenceSession(
                    str(curr_output_dir / output_name),
                    providers=["CPUExecutionProvider"],
                )
                ort_inputs = {
                    ort_session.get_inputs()[0].name: rand_input.numpy()
                }
                ort_outs = ort_session.run(None, ort_inputs)

                # Small numerical differences between torch and ONNX Runtime
                # are expected, so mismatches only produce a warning.
                try:
                    np.testing.assert_allclose(
                        torch_output.detach().cpu().numpy(),
                        ort_outs[0],
                        rtol=1e-03,
                        atol=1e-05,
                        verbose=True,
                    )
                except AssertionError as e:
                    print("WARN: Differences found between torch and onnx:\n")
                    print(e)
                    print(
                        "\nExport will continue, but please verify that your pipeline matches the original."
                    )

                print(f"Exported {curr_output_dir / output_name}")
                export_dataset_json(
                    f"{curr_output_dir}/config_{fold}.json",
                    c, fold, batch_size, use_dynamic_axes, 
                    config, dataset_name, 
                    dataset_json, foreground_intensity_properties,
                )



if __name__ == "__main__":
    export_onnx_model(
        dataset_name_or_id="500",
        output_dir="onnx_model",
        configurations={"2d"},
        batch_size=0,
        folds=(1,),
        save_checkpoints=("checkpoint_final.pth", ), 
        output_names=None,  # tuple of ONNX file names; derived from folds and checkpoints when None
    )
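
After export, the model can be used without PyTorch. Below is a minimal inference sketch, assuming the defaults above (output_dir="onnx_model", fold 1, checkpoint_final.pth, so the files are checkpoint_final_fold_1.onnx and config_1.json). Preprocessing (resampling to the training spacing, normalization, cropping/padding to the patch size) is assumed to have been done already and is not shown; random data stands in for real patches.

import json

import numpy as np
from onnxruntime import InferenceSession

session = InferenceSession(
    "onnx_model/checkpoint_final_fold_1.onnx",
    providers=["CPUExecutionProvider"],
)
with open("onnx_model/config_1.json", encoding="utf-8") as f:
    exported_config = json.load(f)

patch_size = exported_config["model_parameters"]["patch_size"]
num_channels = exported_config["dataset_parameters"]["num_channels"]

# The batch axis is dynamic (batch_size=0 at export time), so any number of
# preprocessed patches can be passed in one call; here 7 random patches.
batch = np.random.rand(7, num_channels, *patch_size).astype(np.float32)
logits = session.run(None, {"input": batch})[0]
print(logits.shape)  # e.g. (7, num_classes, *patch_size)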
