pytorch 模型预实验完成,进一步部署很重要的一步,转存pth模型为ONNX
Reference:
1.Pytorch模型转onnx模型
2.onnx和mnn调用pytorch模型
3.Cannot export saved ScriptModule models to ONNX format #14869
4.Example: AlexNet from PyTorch to ONNX
5.python 调用onnxruntime 实现单输入多输出
# -*- coding: utf-8 -*-
"""
Project: TESTCODE
Creator: CHENRAN
Create time: 2022-06-14 14:20
IDE: PyCharm
Introduction:
"""
import torch
print(torch.__version__)
# pth模型转onnx模型
def pth_to_onnx(input, pth_path, onnx_path,):
model = torch.load(pth_path)
model.eval()
loaded_model = torch.jit.load(pth_path)
loaded_model_output = loaded_model(input.cuda())
# 指定模型的输入,以及onnx的输出路径
"""
model:# 正在运行的模型
input: # 模型输入(或用于多个输入的元组)
onnx_path: # 保存onnx模型格式的路径名字
verbose:# 是否打印网络
opset_version:# 导出模型的ONNX版本
input_names:# 模型的输入名称
output_names:# 模型的输出名称
example_outputs: 模型的输出示例
"""
output_names_list = ['prob', 'prob_logit', 'prob_map', 'height_prob', 'height_prob_logit',
'center_mask', 'visit_mask', 'center_idx', 'offset', 'edge_map',
'seg_map']
torch.onnx.export(model, input, onnx_path,
opset_version=11,
input_names=['input'],
export_params=True,
enable_onnx_checker=True,
verbose=True,
output_names=output_names_list,
example_outputs= loaded_model_output,
)
print("Exporting .pth model to onnx model")
print("Successful!!!")
def main():
example = torch.rand(1, 3, 320, 800)
folder_path = '/media/ubuntu/backup/CRData/Eigenlanes_redefine/Modeling/culane/output/train/weight/'
jit_pth_path = folder_path+'finalmodelalljit.pth'
onnx_path = folder_path+'model.onnx'
pth_to_onnx(input=example, pth_path=jit_pth_path, onnx_path=onnx_path)
if __name__ == '__main__':
main()
Result: