torch转onnx模型

torch转onnx模型

一、前言

onnx是开放神经网络交换格式,用于不同框架之间的迁移,推理方面比原生的torch快很多。本文以MobilenetV3做分类任务为例,实现模型转换。


二、使用步骤

1.torch转换为onnx

代码如下:

#  torch2onnx.py
import torch
import torchvision
from models.mobilenetv3 import MobileNetV3_Large   # 引入模型

torch.set_grad_enabled(False)
torch_model = MobileNetV3_Large(2)  # 初始化网络
torch_model.load_state_dict(torch.load('./mobilenetv3-40-regular.pth'), False)  # 加载训练好的pth模型
batch_size = 1  # 批处理大小
input_shape = (1, 128, 128)  # 输入数据,我这里是灰度训练所以1代表是单通道,RGB训练是3128是图像输入网络的尺寸

# set the model to inference mode
torch_model.eval().cpu()  # cpu推理

x = torch.randn(batch_size, *input_shape).cpu()  # 生成张量
export_onnx_file = "./mobilenetv3-40-regular.onnx"  # 要生成的ONNX文件名
torch.onnx.export(torch_model,
                  x,
                  export_onnx_file,
                  opset_version=10,
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names=["input"],  # 输入名
                  output_names=["output"],  # 输出名
                  dynamic_axes={"input": {0: "batch_size"},  # 批处理变量
                                "output": {0: "batch_size"}})

这样将会得到onnx的转化模型

2.onnx推理

代码如下:

import cv2
from PIL import Image
import onnxruntime as ort
import numpy as np


def softmax(x):
    x = x.reshape(-1)
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)


def postprocess(result):
    return softmax(np.array(result)).tolist()


if __name__ == "__main__":
    onnx_model_path = "./mobilenetv3-40-regular.onnx"  # onnx模型
    ort_session = ort.InferenceSession(onnx_model_path)
    # 输入层名字
    onnx_input_name = ort_session.get_inputs()[0].name
    # 输出层名字
    onnx_outputs_names = ort_session.get_outputs()[0].name
    img = Image.open('./srcImgOK.jpeg').convert("L")  # 需要识别的图像读为灰度
    img = img.resize((128, 128), 0)  # resize成网络输入需要的size
    img = np.asarray(img, np.float32)/255.0  # 归一化
    img = img[np.newaxis, np.newaxis, :, :]
    # 如果是RGB则
    # img = img[np.newaxis, :, :, :]
    input_blob = np.array(img, dtype=np.float32)
    onnx_result = ort_session.run([onnx_outputs_names], input_feed={onnx_input_name: input_blob})
    res = postprocess(onnx_result)  # softmax
    idx = np.argmax(res)
    print(idx)  # 打印识别结果
    print(res[idx])  # 对应的概率

3、可能的异常

RuntimeError: Failed to export an ONNX attribute ‘onnx::Gather’, since it’s not constant, please try to make things (e.g., kernel size) static if possible
在这里插入图片描述点开报错的这个symbolic_helper.py打印一下报错位置print(“>>>>>>”, v.node()),得到mobilenetv3.py:204:0,代表这个模型文件的204行转换有问题,具体问题具体分析,定位到问题就可以找一下解决方案
mobilenetv3为例执行转换代码时可以先打印一下这个out.size(),然后改成常量就可以正常转换在这里插入图片描述

三.总结

转换后的onnx CPU推理速度大概快约10倍,有不对的望指正。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序鱼鱼mj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值