torch转onnx模型
一、前言
onnx是开放神经网络交换格式,用于不同框架之间的迁移,推理方面比原生的torch快很多。本文以MobilenetV3做分类任务为例,实现模型转换。
二、使用步骤
1.torch转换为onnx
代码如下:
# torch2onnx.py
import torch
import torchvision
from models.mobilenetv3 import MobileNetV3_Large # 引入模型
torch.set_grad_enabled(False)
torch_model = MobileNetV3_Large(2) # 初始化网络
torch_model.load_state_dict(torch.load('./mobilenetv3-40-regular.pth'), False) # 加载训练好的pth模型
batch_size = 1 # 批处理大小
input_shape = (1, 128, 128) # 输入数据,我这里是灰度训练所以1代表是单通道,RGB训练是3,128是图像输入网络的尺寸
# set the model to inference mode
torch_model.eval().cpu() # cpu推理
x = torch.randn(batch_size, *input_shape).cpu() # 生成张量
export_onnx_file = "./mobilenetv3-40-regular.onnx" # 要生成的ONNX文件名
torch.onnx.export(torch_model,
x,
export_onnx_file,
opset_version=10,
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input"], # 输入名
output_names=["output"], # 输出名
dynamic_axes={"input": {0: "batch_size"}, # 批处理变量
"output": {0: "batch_size"}})
这样将会得到onnx的转化模型
2.onnx推理
代码如下:
import cv2
from PIL import Image
import onnxruntime as ort
import numpy as np
def softmax(x):
x = x.reshape(-1)
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0)
def postprocess(result):
return softmax(np.array(result)).tolist()
if __name__ == "__main__":
onnx_model_path = "./mobilenetv3-40-regular.onnx" # onnx模型
ort_session = ort.InferenceSession(onnx_model_path)
# 输入层名字
onnx_input_name = ort_session.get_inputs()[0].name
# 输出层名字
onnx_outputs_names = ort_session.get_outputs()[0].name
img = Image.open('./srcImgOK.jpeg').convert("L") # 需要识别的图像读为灰度
img = img.resize((128, 128), 0) # resize成网络输入需要的size
img = np.asarray(img, np.float32)/255.0 # 归一化
img = img[np.newaxis, np.newaxis, :, :]
# 如果是RGB则
# img = img[np.newaxis, :, :, :]
input_blob = np.array(img, dtype=np.float32)
onnx_result = ort_session.run([onnx_outputs_names], input_feed={onnx_input_name: input_blob})
res = postprocess(onnx_result) # softmax
idx = np.argmax(res)
print(idx) # 打印识别结果
print(res[idx]) # 对应的概率
3、可能的异常
RuntimeError: Failed to export an ONNX attribute ‘onnx::Gather’, since it’s not constant, please try to make things (e.g., kernel size) static if possible
点开报错的这个symbolic_helper.py打印一下报错位置print(“>>>>>>”, v.node()),得到mobilenetv3.py:204:0,代表这个模型文件的204行转换有问题,具体问题具体分析,定位到问题就可以找一下解决方案
mobilenetv3为例执行转换代码时可以先打印一下这个out.size(),然后改成常量就可以正常转换
三.总结
转换后的onnx CPU推理速度大概快约10倍,有不对的望指正。