Pytorch模型转onnx模型(多输入)

参考

【1】https://zhuanlan.zhihu.com/p/41255090 理论
【2】https://zhuanlan.zhihu.com/p/272767300 理论
【3】https://zhuanlan.zhihu.com/p/273566106 实战
【4】https://zhuanlan.zhihu.com/p/286298001 实战
【5】https://www.jianshu.com/p/65cfb475584a 实战
【6】https://pytorch.apachecn.org/docs/1.4/31.html 实战
【7】https://heroinlin.github.io/2018/08/15/Pytorch/Pytorch_export_onnx/ 踩坑记录
【8】https://blog.csdn.net/SilentOB/article/details/102863944 理论
【9】https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md#polishing-the-model 实战
【10】https://blog.csdn.net/ZM_Yang/article/details/103977679 onnxruntime源码解析
【11】https://pytorch.apachecn.org/docs/1.0/onnx.html 记录了ONNX支持的运算符

简介

Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。
目前官方支持加载ONNX模型并进行推理的深度学习框架有: Caffe2, PyTorch, MXNet,ML.NET,TensorRT 和 Microsoft CNTK,并且 TensorFlow 也非官方的支持ONNX。

ONNX是一个开放式规范,由以下组件组成:
1.可扩展计算图模型的定义
2.标准数据类型的定义
3.内置运算符的定义

实验环境

Win10 + python3.6.12 + pytorch1.7.1 + onnx1.7.0 +
onnxruntime 1.6.0

Pytorch2ONNX

import torch
import onnx
network = Network() #加载网络
network.eval() #不进行梯度传递
dummy_input = torch.randn(1,3,224,224) # 网络输入大小
torch.onnx.export(model, 
                  dummy_input, 
                  "mynetwork.onnx",  
                  export_params=True,  # 是否保存模型的训练好的参数
                  verbose=True, # 是否输出debug描述
                  input_names=['input1', 'input2'], # 定义输入结点的名字,有几个输入就定义几个
                  output_names=['output'], # 定义输出结点的名字
                  opset_version=11, #onnx opset的库版本
                  do_constant_folding=True, # whether do constant-folding optimization 该优化会替换全为常数输入的分支为最终结果
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK )
                  # 选择用那种ops.有三种选择:
                  # 1. OperatorExportTypes.ONNX(默认),导出为标准ONNX运算符
                  # 2. OperatorExportTypes.ONNX_ATEN, 导出为aten运算符
                  # 3. OperatorExportTypes.ONNX_ATEN_FALLBACK, 在ONNX中注册的运算符使用ONNX,其余使用aten
                 	Example graph::
                 	graph(%0 : Float)::
                  	%3 : int = prim::Constant[value=0]()
                  	%4 : Float = aten::triu(%0, %3) # missing op
                  	%5 : Float = aten::mul(%4, %0) # registered op
                  	return (%5)
            	 	is exported as::
                 	graph(%0 : Float)::
                  	%1 : Long() = onnx::Constant[value={0}]()
                  	%2 : Float = aten::ATen[operator="triu"](%0, %1)  # missing op
                  	%3 : Float = onnx::Mul(%2, %0) # registered op
                  	return (%3)

ONNX推理

model = onnx.load('mynetwork.onnx')
#检查IR是否良好
onnx.checker.check_model(model)
#输出一个图形的可读表示方式
onnx.helper.printable_graph(model.graph)
# 以上部分在我的实验中基本没用,能运行的模型和不能运行的模型都一样输出

import onnxruntime
onnx_session = onnxruntime.InferenceSession('mynetwork.onnx')
inputs = (torch.randn(1,3,224,224),  torch.randn(1,3,224,224))
output_name = onnx_session.get_outputs()[0].name
input_name1 = onnx_session.get_inputs()[0].name
input_name2 = onnx_session.get_inputs()[1].name

def to_numpy(tensor):
    return tensor.detach().cpu().numpy().astype(np.float32) if tensor.requires_grad else tensor.cpu().numpy().astype(np.float32)
res = onnx_session.run([output_name], {input_name1:to_numpy(inputs[0]), input_name2:to_numpy(inputs[1])})[0]
#这里run函数的输入格式需要格外注意,否则会报错:
#""" run(self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: #onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object] """
#输入是1.List[output_name] 也就是一个**列表**,记录了输出的node的名字,有几个输出就写几个名字
#     2.Dict[str, input]是一个**字典**,记录了输入Node名字以及他们的值,有几个输入就写几个key-value对。

#除此之外,run的输入必须是numpy.narray类型,不能是tensor型。数据类型一般是np.float32。

常见坑

  1. onnx只能输出静态图,因此不支持if-else分支。一次只能走一个分支。如果代码中有if-else语句,需要改写。
  2. onnx不支持步长为2的切片。例如a[::2,::2]
  3. onnx不支持对切片对象赋值。例如a[0,:,:,:]=b, 可以用torch.cat改写
  4. onnx里面的resize要求output shape必须为常量。可以用以下代码解决:
if isinstance(size, torch.Size):
    size = tuple(int(x) for x in size)
  1. 参考7中提到,pytorch导出onnx不支持adaptiveAvgPool2d, Expand, ReLU6.
  2. 不支持自定义的conv,支持conv1d,conv2d,conv3d。
  3. C++调用时数据类型要是32位浮点型!!!

附录

【1】一个好用的神经网络图示工具。输入onnx,tensorflow等模型,可以显示模型的流程图。
https://gitee.com/yousuos/onnx-simplifier
【2】一个简化模型流程图的工具
https://github.com/daquexian/onnx-simplifier
【3】ONNX动态输入
https://www.codenong.com/cs107117759/

  • 11
    点赞
  • 48
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值