1. Pytorch导出ONNX
torch.onnx.export函数实现了pytorch模型到onnx模型的导出,在pytorch1.11.0中,torch.onnx.export函数参数如下:
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None,
enable_onnx_checker=True, use_external_data_format=False):
参数比较多,但常用的有如下几个:
model: pytorch模型
args: 第一个参数model的输入数据,因为模型的输入可能不止一个,因此采用元组作为参数
f: 导出的onnx模型文件路径
export_params: 导出的onnx模型文件可以包含网络结构与权重参数,如果设置该参数为False,则导出的onnx模型文件只包含网络结构,因此,一般保持默认为True即可
verbose: 该参数如果指定为True,则在导出onnx的过程中会打印详细的导出过程信息
input_names: 为输入节点指定名称,因为输入节点可能多个,因此该参数是一个列表
output_names: 为输出节点指定名称,因为输出节点可能多个,因此该参数是一个列表
opset_version: 导出onnx时参考的onnx算子集版本
dynamic_axes: 指定输入输出的张量,哪些维度是动态的,通过用字典的形式进行指定,如果某个张量的某个维度被指定为字符串或者-1,则认为该张量的该维度是动态的,但是一般建议只对batch维度指定动态,这样可提高性能,具体的格式见下面的代码
如下代码,定义了一个包含卷积层、relu激活层的网络,将该网络导出onnx模型,设置了输入、输出的batch、height、width 3个维度是动态的
import torch
import torch.nn as nn
import torch.onnx
import os
# 定义一个模型
class Model(torch.nn.Module):
    """Minimal conv + ReLU network used to demonstrate ONNX export."""

    def __init__(self):
        super().__init__()
        # Single-channel 3x3 convolution; padding=1 preserves spatial size.
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.relu = nn.ReLU()
        # Deterministic parameters: all-ones kernel, zero bias.
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)

    def forward(self, x):
        """Apply the convolution followed by ReLU."""
        return self.relu(self.conv(x))
model = Model()
dummy = torch.zeros(1, 1, 3, 3)

# Dimensions named here become dynamic (-1) in the exported graph.
# Usually only the batch axis should be dynamic; height/width are
# included purely for demonstration.
dynamic_axes = {
    "image": {0: "batch", 2: "height", 3: "width"},
    "output": {0: "batch", 2: "height", 3: "width"},
}

torch.onnx.export(
    model,
    (dummy,),                 # model inputs; a tuple, hence the parentheses
    "demo.onnx",              # destination path of the exported model
    verbose=True,             # print a detailed trace of the export
    input_names=["image"],    # names assigned to the graph input nodes
    output_names=["output"],  # names assigned to the graph output nodes
    opset_version=11,         # ONNX operator-set version to target
    dynamic_axes=dynamic_axes,
)
print("Done.!")
2. netron可视化
通过netron可视化可以看到网络输入层为image,输出层为output,这些层名都是在onnx导出时指定的,另外红色框标注处,显示batch、height、width三个维度为动态的。
3. 修改onnx模型
1)修改模型输入尺寸
(1)动态尺寸修改为静态尺寸
import onnx
import onnxruntime as rt
import os
import numpy as np
import argparse
class fix_dim_tools:
    """Freeze dynamic input dimensions of an ONNX model to fixed shapes.

    Loads a model, substitutes user-supplied shapes for any dynamic
    (zero-valued) input dims, runs onnxruntime once with fake data to
    discover the resulting output shapes, then rewrites the model's
    input/output dim info via onnx.tools.update_model_dims.
    """

    def __init__(self, model_path, inputs_shape, inputs_dtype):
        """
        model_path: path to the .onnx file to fix.
        inputs_shape: list of shape lists, one entry per model input.
        inputs_dtype: list of 'float'/'int' tags (one per input), or None
            to treat every input as float32.
        """
        assert os.path.exists(model_path), "{} not exists".format(model_path)
        if inputs_dtype is None:
            print('inputs_dtype is not defined, use float for all inputs node')
            inputs_dtype = ['float'] * len(inputs_shape)
        else:
            assert len(inputs_shape) == len(inputs_dtype), "inputs shape list should have same length as inputs_dtype"
        self.model = onnx.load(model_path)
        self.model_path = model_path
        self.inputs_shape = inputs_shape
        self.inputs_dtype = inputs_dtype
        self.inputs_shape_dict = {}   # input name -> concrete shape to use
        self.inputs_type_dict = {}    # input name -> dtype tag ('float'/'int')
        self.outputs_shape_dict = {}  # output name -> shape observed at runtime

    def check_dynamic_input(self):
        """Collect per-input shapes; return True if any input is dynamic.

        A dim that was exported as dynamic carries a dim_param, so its
        dim_value reads as 0 — that is what the `0 in dim_values` test
        detects.
        """
        inputs_number = len(self.model.graph.input)
        assert inputs_number == len(self.inputs_shape), "model has {} inputs, but {} inputs_shape was given, not match".format(inputs_number, len(self.inputs_shape))
        state = False
        for i in range(inputs_number):
            _input = self.model.graph.input[i]
            dim_values = [dim.dim_value for dim in _input.type.tensor_type.shape.dim]
            if 0 in dim_values:
                state = True
                print('Input node:{} is dynamic input, the shape info is {}. Using given shape-{} instead.'.format(_input.name, dim_values, self.inputs_shape[i]))
                self.inputs_shape_dict[_input.name] = self.inputs_shape[i]
            else:
                # Static input: keep the model's own shape, ignore the CLI one.
                print('Input node:{} is normal input, the shape info is {}. Ignore given shape-{}'.format(_input.name, dim_values, self.inputs_shape[i]))
                self.inputs_shape_dict[_input.name] = dim_values
            self.inputs_type_dict[_input.name] = self.inputs_dtype[i]
        return state

    def run_onnxruntime_to_get_output_shape(self):
        """Run the model once on random data and record the output shapes."""
        # Explicit provider list: onnxruntime >= 1.9 GPU builds require it;
        # CPU is enough since we only need the output shapes.
        sess = rt.InferenceSession(self.model_path, providers=['CPUExecutionProvider'])
        inputs_dict = {}
        # Generate a fake feed for every input.
        for key in self.inputs_shape_dict:
            dtype_tag = self.inputs_type_dict[key]
            if dtype_tag == 'int':
                # NOTE(review): the 'int' tag produces uint8 data (image-style
                # input, values 0..254) — confirm this matches the model's
                # declared tensor element type.
                inputs_dict[key] = np.random.randint(0, 255, self.inputs_shape_dict[key]).astype(np.uint8)
            elif dtype_tag == 'float':
                inputs_dict[key] = np.random.randn(*self.inputs_shape_dict[key]).astype(np.float32)
            else:
                # Previously an unknown tag silently left the feed out and
                # sess.run failed with an opaque error; fail fast instead.
                raise ValueError("unsupported input dtype '{}' for input '{}', expected 'float' or 'int'".format(dtype_tag, key))
        outputs_name = [node.name for node in sess.get_outputs()]
        res = sess.run(outputs_name, inputs_dict)
        for node_name, value in zip(outputs_name, res):
            self.outputs_shape_dict[node_name] = list(value.shape)
            print("After inference, output:{} shape is [{}]".format(node_name, list(value.shape)))

    def run(self):
        """Fix the model dims in memory; return True on success."""
        dynamic_input = self.check_dynamic_input()
        if dynamic_input is False:
            print("{} hasn't dynamic input, no file will be generated.".format(self.model_path))
            return False
        self.run_onnxruntime_to_get_output_shape()
        try:
            from onnx.tools import update_model_dims
            # Keep the returned model (the helper also mutates in place,
            # but relying on the return value is the documented contract).
            self.model = update_model_dims.update_inputs_outputs_dims(
                self.model, self.inputs_shape_dict, self.outputs_shape_dict)
            return True
        except Exception as e:
            print("Automatically fix failed. Please try to export an non-dynamic onnx model first. Exception: {}".format(repr(e)))
            return False

    def export(self, output_path):
        """Serialize the (possibly fixed) model to output_path."""
        onnx.save(self.model, output_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="This script can fix the dim for onnx model which have dynamic inputs.\n"
                    "Using this file such as: python freeze_inshape_for_onnx_model.py "
                    "--onnx_file_path model_in.onnx "
                    "--output_path model_out.onnx "
                    "--inputs_shape 1,3,112,112#1,3,64,64 "
                    "--inputs_dtype float#int")
    parser.add_argument('--onnx_file_path', type=str, default='./models/pplcnet.onnx',
                        help='input model path')
    parser.add_argument('--output_path', type=str, default='./models/pplcnet_freeze.onnx',
                        help='output model path')
    parser.add_argument('--inputs_shape', type=str, default='1,3,224,224',
                        help='inputs shape list, write as 1,3,112,112#1,3,64,64 for multi-input.')
    # Help text fixed: this flag carries dtypes, not shapes ("Defualt" typo too).
    parser.add_argument('--inputs_dtype', type=str, required=False,
                        help='(Optional) inputs dtype list, write as float#int for multi-input. Default as float')
    args = parser.parse_args()

    # '#' separates per-input shapes; ',' separates the dims of one shape.
    inputs_shape = [[int(dim) for dim in shape_str.split(',')]
                    for shape_str in args.inputs_shape.split('#')]
    inputs_dtype = None if args.inputs_dtype is None else args.inputs_dtype.split('#')

    fdt = fix_dim_tools(args.onnx_file_path, inputs_shape, inputs_dtype)
    if fdt.run() is True:
        fdt.export(args.output_path)
        print('Success! Fix dim done. Export new model to path:{}'.format(args.output_path))
(2)修改batch size
import onnx
import onnx_graphsurgeon as gs
if __name__ == '__main__':
    model_path = "./models/onnx/pplcnet_1b.onnx"
    output_model_path = "./models/onnx/pplcnet_4b.onnx"
    new_batch_size = 4  # desired new batch size

    # Import the ONNX graph into graphsurgeon's editable representation.
    graph = gs.import_onnx(onnx.load(model_path))

    # Rewrite the batch (first) dimension of EVERY graph input and output,
    # not just the first one, so multi-input/multi-output models work too.
    for tensor in list(graph.inputs) + list(graph.outputs):
        shape = tensor.shape
        shape[0] = new_batch_size
        tensor.shape = shape

    # NOTE(review): intermediate tensor shapes recorded in the graph are not
    # touched here; run ONNX shape inference on the result if they matter.
    onnx.save(gs.export_onnx(graph), output_model_path)