什么是ONNX:ONNX(Open Neural Network Exchange)是深度学习模型的一种保存格式,由微软、亚马逊、Facebook和IBM等公司共同开发开放式的文件格式,用于存储训练好的模型、简化深度学习模型的部署和迁移。
先贴个微软的官网教程:将 PyTorch 模型转换为 ONNX 格式 | Microsoft Learn
下面提供示例
1.先安装onnx包,在需要使用python环境下 pip install onnx
如果import onnx显示错误信息:“DLL load failed while importing onnx_cpp2py_export: 动态链接库(DLL)初始化例程失败。”,请将onnx降级为1.16.1或调整为其他版本。
2.判断文件类型来加载模型参数,调用官方提供的函数转换为onnx。
import os
import torch
from torch import nn
import onnx
# # # Model For Example
# 定义一个简单的PyTorch模型,作为示例
class ExampleModel(nn.Module):
def __init__(self):
super(ExampleModel, self).__init__()
self.ln = nn.Linear(10,10)
def forward(self, x):
x = torch.softmax(self.ln(x),dim=1)
return x
def transpose_onnx(model, dummy_input, model_path, save_path):
# 判断目标文件是否存在
if os.path.isfile(save_path):
print(f"The save file exists.")
return
# 将模型和变量转到CPU
# set the model to cpu
device = "cpu"
model = model.to(device)
dummy_input = dummy_input.to(device)
# set the model to inference mode
model.eval()
# 根据模型文件类型加载权重
mdl_type = model_path.split(".")
print(f"using {mdl_type[-1]} type file")
if mdl_type[-1] == "bin":
print(f"using {mdl_type[-1]} type")
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
elif mdl_type[-1] == "pth":
print(f"using {mdl_type[-1]} type")
checkpoint = torch.load(model_path, map_location='cpu')
if isinstance(checkpoint,nn.Module):
# if use torch.save(model, 'model.pth')) while saving model
model = checkpoint
else:
# if use torch.save(model.state_dict(), 'model.pth') while saving model
model.load_state_dict(checkpoint)
else:
print(f"using {mdl_type[-1]} type")
model = torch.load(model_path, map_location='cpu')
# 调用库函数
torch.onnx.export(model, # model being run
dummy_input, # model input (or a tuple for multiple inputs)
save_path, # where to save the model
export_params=True, # store the trained parameter weights inside the model file
opset_version=11, # the ONNX version to export the model to
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
print(" ")
print('Model has been converted to ONNX')
return
if __name__ == "__main__":
# 1、训练好的模型文件(.pth或.pt)的存放路径model_path; 即将生成的模型文件(.onnx)的存放路径save_path
model_path = '.\\example_model.pth'
save_path = '.\\example_model.onnx'
# 2、声明模型以及伪模型输入变量
model = ExampleModel()
dummy_input = torch.rand(1, 10)
# 3、调用转换函数
transpose_onnx(model=model, dummy_input=dummy_input, model_path=model_path,save_path=save_path)