import torch
from torch import nn
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
from nnunetv2.paths import nnUNet_results
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from nnunetv2.experiment_planning.plan_and_preprocess_api import plan_experiments
import json
def load_json(file: str):
    """Read a JSON file and return the decoded value.

    The file is decoded explicitly as UTF-8 (``utf-8-sig`` also accepts and
    strips a leading byte-order mark). Without ``encoding=``, ``open`` uses the
    platform locale encoding (e.g. GBK on Chinese Windows), which fails on
    UTF-8 files containing non-ASCII text. A BOM written by some Windows editors
    would otherwise make ``json.load`` raise.

    Args:
        file: path to the JSON file.

    Returns:
        The parsed JSON value.
    """
    with open(file, 'r', encoding='utf-8-sig') as f:
        return json.load(f)
from nnunetv2.run.run_training import run_training
import os
def build_network_architecture(plans_manager: PlansManager,
                               dataset_json,
                               configuration_manager: ConfigurationManager,
                               num_input_channels,
                               enable_deep_supervision: bool = True) -> nn.Module:
    """Construct the network described by the nnU-Net plans.

    Thin wrapper around ``get_network_from_plans``: the first four arguments are
    forwarded positionally, unchanged, and ``enable_deep_supervision`` is passed
    on as the ``deep_supervision`` keyword.

    Args:
        plans_manager: wrapper around the loaded plans.json.
        dataset_json: the loaded dataset.json content.
        configuration_manager: the selected configuration (e.g. '3d_fullres').
        num_input_channels: number of channels the network input has.
        enable_deep_supervision: forwarded as ``deep_supervision``.

    Returns:
        The ``nn.Module`` returned by ``get_network_from_plans``.
    """
    model = get_network_from_plans(
        plans_manager,
        dataset_json,
        configuration_manager,
        num_input_channels,
        deep_supervision=enable_deep_supervision,
    )
    return model
if __name__ == "__main__":
    # Export one fold of a trained nnU-Net v2 model to ONNX.

    # ---- user settings -------------------------------------------------
    dataset_id = 511
    task_name = "roi"  # file stem of the exported model: <task_name>.onnx
    checkpoint_name = 'checkpoint_final.pth'
    configuration = '3d_fullres'
    fold = '0'
    # Spatial size of the dummy input used to trace the network. It should
    # equal the patch size this configuration was trained with (plans.json);
    # a mismatch is reported below but the value here is still used.
    patch_size = [96, 128, 192]

    # ---- locate and load the model files --------------------------------
    # Layout used by this script:
    #   <nnUNet_results>/<dataset_name>/{dataset.json, plans.json}
    #   <nnUNet_results>/<dataset_name>/fold_<fold>/<checkpoint_name>
    # NOTE: nnU-Net normally nests these under a
    # <trainer>__<plans>__<configuration> sub-folder; adjust
    # model_training_output_dir if your results are laid out that way.
    dataset_name = convert_id_to_dataset_name(dataset_id)
    model_training_output_dir = os.path.join(nnUNet_results, dataset_name)
    fold_dir = os.path.join(model_training_output_dir, f'fold_{fold}')

    # The checkpoint is a pickled dict (the weights live under
    # "network_weights"). Since torch 2.6, torch.load defaults to
    # weights_only=True, which can reject such a dict. weights_only=False
    # unpickles arbitrary objects, so only load checkpoints you trust.
    checkpoint = torch.load(os.path.join(fold_dir, checkpoint_name),
                            map_location=torch.device('cpu'),
                            weights_only=False)
    dataset_json = load_json(os.path.join(model_training_output_dir, 'dataset.json'))
    plans = load_json(os.path.join(model_training_output_dir, 'plans.json'))

    plans_manager = PlansManager(plans)
    configuration_manager = plans_manager.get_configuration(configuration)
    num_input_channels = determine_num_input_channels(plans_manager, configuration_manager,
                                                      dataset_json)

    # Build the network for inference: with deep supervision enabled, the
    # forward pass returns outputs at several resolutions, and only the first
    # would be named 'output' in the ONNX graph. nnU-Net's own predictor also
    # builds the network with deep supervision disabled.
    network = build_network_architecture(
        plans_manager,
        dataset_json,
        configuration_manager,
        num_input_channels,
        enable_deep_supervision=False,
    )
    network.load_state_dict(checkpoint["network_weights"])
    network.eval()

    planned_patch_size = list(configuration_manager.patch_size)
    if list(patch_size) != planned_patch_size:
        print(f'WARNING: patch_size {patch_size} differs from the planned patch size '
              f'{planned_patch_size} of configuration {configuration!r}.')

    # One sample with every input channel of the dataset (not just one),
    # followed by the spatial patch dimensions.
    dummy_input = torch.randn(1, num_input_channels, *patch_size)
    onnx_file = os.path.join(fold_dir, f'{task_name}.onnx')
    torch.onnx.export(
        network,
        dummy_input,
        onnx_file,
        input_names=['input'],
        output_names=['output'],
        # For a variable batch size, additionally pass:
        # dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )
    print(f'ONNX model written to {onnx_file}')
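# 参考博客 (reference blog): https://blog.csdn.net/qq_35007834/article/details/134707439?spm=1001.2014.3001.5501
# 按照大佬的模板修改而来,可以生成 ONNX 模型,但导出时会有警告,欢迎评论交流!
# (Adapted from the template in the blog above. It produces the ONNX model, but the export emits warnings. Comments welcome!)
# 补充:如果运行不成功,可直接修改 predict_from_raw_data.py 中的 predict_logits_from_preprocessed_data(见下方),正常推理一次即可导出。
# (Note: if this script fails, you can instead edit predict_logits_from_preprocessed_data in predict_from_raw_data.py, shown below, and run a normal inference once to export.)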
def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
    """
    IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
    TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
    RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
    SEE convert_predicted_logits_to_segmentation_with_correct_shape

    Load each parameter set in ``self.list_of_parameters`` into
    ``self.network`` (presumably one per fold/checkpoint), run a sliding-window
    prediction with it, and average the resulting logits.

    Args:
        data: the preprocessed input tensor.

    Returns:
        The (averaged) logits as a CPU tensor.
    """
    # Limit torch's thread pool to default_num_processes for the duration of
    # the prediction. The caller's setting is restored in `finally`, so an
    # exception (e.g. OOM) no longer leaves the process with a changed count.
    n_threads = torch.get_num_threads()
    torch.set_num_threads(min(default_num_processes, n_threads))
    try:
        with torch.no_grad():
            prediction = None
            for params in self.list_of_parameters:
                # messing with state dict names...
                # (a torch.compile'd network wraps the real module in _orig_mod)
                if not isinstance(self.network, OptimizedModule):
                    self.network.load_state_dict(params)
                else:
                    self.network._orig_mod.load_state_dict(params)

                # ---- optional one-off ONNX export (disabled) ----------------
                # Uncomment to export the network while it holds these weights;
                # it runs once per parameter set, so remove it again afterwards.
                # A CPU *copy* is exported: calling self.network.cpu() directly
                # would leave the network on the CPU for the prediction below.
                # The dummy input should be (1, num_input_channels, *patch_size)
                # of this model.
                # import copy
                # export_net = copy.deepcopy(
                #     self.network._orig_mod if isinstance(self.network, OptimizedModule) else self.network
                # ).cpu().eval()
                # x1 = torch.randn(1, 1, 96, 96, 128)
                # torch.onnx.export(
                #     export_net,
                #     x1,
                #     r'C:/nnUNet/nnUNetFrame/DATASET/nnUNet_results/test/tooth_seg.onnx',
                #     export_params=True,
                #     do_constant_folding=True,
                #     opset_version=11,
                #     input_names=['modelInput'],
                #     output_names=['modelOutput'],
                #     # keys must match input_names/output_names exactly:
                #     # dynamic_axes={'modelInput': {0: 'batch_size'}, 'modelOutput': {0: 'batch_size'}},
                # )
                # print("Model has been converted to onnx!")

                # why not leave prediction on device if perform_everything_on_device? Because this may cause the
                # second iteration to crash due to OOM. Grabbing that with try except causes way more bloated code
                # than this actually saves computation time
                logits = self.predict_sliding_window_return_logits(data).to('cpu')
                if prediction is None:
                    prediction = logits
                else:
                    prediction += logits

            if len(self.list_of_parameters) > 1:
                prediction /= len(self.list_of_parameters)

            if self.verbose:
                print('Prediction done')
        # `prediction` was built exclusively from .to('cpu') results, so it is
        # already on the CPU; no further transfer is needed.
    finally:
        torch.set_num_threads(n_threads)
    return prediction