import torch
# Export a trained model checkpoint to an ONNX file ("tinyhitnet.onnx").
#
# NOTE(review): `PredictModel` and `args` are neither defined nor imported in
# this file as shown; they presumably come from the training code (the model
# class and its argparse namespace). TODO confirm and add the proper import.

device = torch.device('cuda')
model_path = './hitnet_sf_finalpass/version_40/checkpoints/epoch=6-step=44890.ckpt'
# Shape of each dummy input; presumably (batch, channels, height, width).
input_shape = (1, 3, 480, 640)

# Dummy left/right images used only to trace the graph: their values do not
# matter, only shape, dtype and device.
left_in = torch.randn(*input_shape, device=device)
right_in = torch.randn(*input_shape, device=device)

# Load onto the CPU first so loading does not depend on the device the
# checkpoint was saved from; the model is moved to `device` afterwards.
# NOTE(review): on newer PyTorch releases torch.load defaults to
# weights_only=True, which may reject checkpoints that pickle extra objects;
# pass weights_only=False only for trusted checkpoints.
ckpt = torch.load(model_path, map_location='cpu')
model = PredictModel(**vars(args))
model.load_state_dict(ckpt['state_dict'])
# The model must live on the same device as the dummy inputs, otherwise
# tracing inside torch.onnx.export fails with a device-mismatch error.
model = model.to(device).eval()

# Give the graph's inputs and outputs readable names (one per input tensor).
input_names = ('input_1', 'input_2')
output_names = ["output_1"]

torch.onnx.export(
    model,
    (left_in, right_in),  # tuple is unpacked: model(left_in, right_in)
    "tinyhitnet.onnx",
    opset_version=13,  # choose an opset version the target runtime supports
    verbose=True,
    input_names=input_names,
    output_names=output_names,
)
print('export onnx model successful!')