onnx的基本操作
最近在对模型进行量化时候,模型格式转变为onnx模型了,因此需要对onnx进行加载、运行以及量化(权重/输入/输出)。故,对onnx模型的相关操作进行简单的学习,并写下了这边博客,若有错误请指出,谢谢。
一、onnx的配置环境
onnx的环境主要包含两个包onnx和onnxruntime,我们可以通过pip安装这两个依赖包。
pip install onnxruntime
pip install onnx
二、获取onnx模型的输出层
import onnx
# 加载模型
model = onnx.load('onnx_model.onnx')
# 检查模型格式是否完整及正确
onnx.checker.check_model(model)
# 获取输出层,包含层名称、维度信息
output = self.model.graph.output
print(output)
三、获取中节点输出数据
onnx模型通常只能拿到最后输出节点的输出数据,若想拿到中间节点的输出数据,需要我们自己添加相应的输出节点信息;首先需要构建指定的节点(层名称、数据类型、维度信息);然后再通过insert的方式将节点插入到模型中。
import onnx
from onnx import helper
# 加载模型
model = onnx.load('onnx_model.onnx')
# 创建中间节点:层名称、数据类型、维度信息
prob_info = helper.make_tensor_value_info('layer1',onnx.TensorProto.FLOAT, [1, 3, 320, 280])
# 将构建完成的中间节点插入到模型中
model.graph.output.insert(0, prob_info)
# 保存新的模型
onnx.save(model, 'onnx_model_new.onnx')
# 扩展:
# 删除指定的节点方法: item为需要删除的节点
# model.graph.output.remove(item)
四、onnx前向InferenceSession的使用
关于onnx的前向推理,onnx使用了onnxruntime计算引擎。
onnx runtime是一个用于onnx模型的推理引擎。微软联合Facebook等在2017年搞了个深度学习以及机器学习模型的格式标准–ONNX,顺路提供了一个专门用于ONNX模型推理的引擎(onnxruntime)。
import onnxruntime
# 创建一个InferenceSession的实例,并将模型的地址传递给该实例
sess = onnxruntime.InferenceSession('onnxmodel.onnx')
# 调用实例sess的润方法进行推理
outputs = sess.run(output_layers_name, {input_layers_name: x})
1. 创建实例,源码分析
class InferenceSession(Session):
"""
This is the main class used to run a model.
"""
def __init__(self, path_or_bytes, sess_options=None, providers=[]):
"""
:param path_or_bytes: filename or serialized model in a byte string
:param sess_options: session options
:param providers: providers to use for session. If empty, will use
all available providers.
"""
self._path_or_bytes = path_or_bytes
self._sess_options = sess_options
self._load_model(providers)
self._enable_fallback = True
Session.__init__(self, self._sess)
def _load_model(self, providers=[]):
if isinstance(self._path_or_bytes, str):
self._sess = C.InferenceSession(
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
True)
elif isinstance(self._path_or_bytes, bytes):
self._sess = C.InferenceSession(
self._sess_options if self._sess_options else C.get_default_session_options(), self._path_or_bytes,
False)
# elif isinstance(self._path_or_bytes, tuple):
# to remove, hidden trick
# self._sess.load_model_no_init(self._path_or_bytes[0], providers)
else:
raise TypeError("Unable to load from type '{0}'".format(type(self._path_or_bytes)))
self._sess.load_model(providers)
self._sess_options = self._sess.session_options
self._inputs_meta = self._sess.inputs_meta
self._outputs_meta = self._sess.outputs_meta
self._overridable_initializers = self._sess.overridable_initializers
self._model_meta = self._sess.model_meta
self._providers = self._sess.get_providers()
# Tensorrt can fall back to CUDA. All others fall back to CPU.
if 'TensorrtExecutionProvider' in C.get_available_providers():
self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
else:
self._fallback_providers = ['CPUExecutionProvider']
在_load_model函数,可以发现在load模型的时候是通过C.InferenceSession,并且将相关的操作也委托给该类。从导入语句from onnxruntime.capi import _pybind_state as C可知其实就是一个c++实现的Python接口,其源码在onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc中。
2. 模型推理run,源码分析
def run(self, output_names, input_feed, run_options=None):
"""
Compute the predictions.
:param output_names: name of the outputs
:param input_feed: dictionary ``{ input_name: input_value }``
:param run_options: See :class:`onnxruntime.RunOptions`.
::
sess.run([output_name], {input_name: x})
"""
num_required_inputs = len(self._inputs_meta)
num_inputs = len(input_feed)
# the graph may have optional inputs used to override initializers. allow for that.
if num_inputs < num_required_inputs:
raise ValueError("Model requires {} inputs. Input Feed contains {}".format(num_required_inputs, num_inputs))
if not output_names:
output_names = [output.name for output in self._outputs_meta]
try:
return self._sess.run(output_names, input_feed, run_options)
except C.EPFail as err:
if self._enable_fallback:
print("EP Error: {} using {}".format(str(err), self._providers))
print("Falling back to {} and retrying.".format(self._fallback_providers))
self.set_providers(self._fallback_providers)
# Fallback only once.
self.disable_fallback()
return self._sess.run(output_names, input_feed, run_options)
else:
raise
在run函数中,数据的推理是通过调用self._sess.run来进行前向推理的。同理该函数的具体实现实在c++的InferenceSession类中实现的。
五、遇到的一些问题
- 输入数据维度或类型不正确
从上图可以看出,该模型的输入数据的维度信息为[1, 3, 480, 640],输入数据类型为float32;所以在构建输入数据时,一定要按照该信息去构建,否则代码将会报错。
注: python调用c++代码是通过pybind11实现。