ONNX Basic Operations

1. Exporting ONNX from PyTorch

The torch.onnx.export function exports a PyTorch model to an ONNX model. In PyTorch 1.11.0, torch.onnx.export has the following parameters:

def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL, 
           input_names=None, output_names=None, aten=False, export_raw_ir=False, 
           operator_export_type=None, opset_version=None, _retain_param_name=True, 
           do_constant_folding=True, example_outputs=None, strip_doc_string=True, 
           dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, 
           enable_onnx_checker=True, use_external_data_format=False):

There are quite a few parameters, but the commonly used ones are the following:

model: the PyTorch model to be exported

args: the input(s) that will be fed to model; since a model may have more than one input, a tuple is used

f: the file path of the exported ONNX model

export_params: the exported ONNX file can contain both the network structure and the weights; if this is set to False, the exported file contains only the network structure, so it is normally left at its default value of True

verbose: if set to True, detailed information is printed during the ONNX export

input_names: names for the input nodes; since there may be more than one input, this is a list

output_names: names for the output nodes; since there may be more than one output, this is a list

opset_version: the ONNX operator-set version to target when exporting

dynamic_axes: specifies which dimensions of which input/output tensors are dynamic. It is given as a dict that maps each input/output name to its dynamic axes, either as a list of axis indices or as a dict mapping each axis index to a name; in the exported ONNX model those dimensions appear as symbolic (shown as -1 or a name) rather than fixed sizes. In general it is recommended to make only the batch dimension dynamic, which helps downstream performance; the exact format is shown in the code below

The following code defines a network consisting of a convolution layer and a ReLU activation, exports it to an ONNX model, and marks the batch, height, and width dimensions of both the input and the output as dynamic:

import torch
import torch.nn as nn
import torch.onnx
import os

# Define a simple model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv.weight.data.fill_(1) # initialize the weights to 1
        self.conv.bias.data.fill_(0) # initialize the bias to 0
    
    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


model = Model()
dummy = torch.zeros(1, 1, 3, 3)

torch.onnx.export(
    model, 

    # inputs fed to the model; a tuple is expected, hence the parentheses
    (dummy,), 

    # path of the exported ONNX file
    "demo.onnx", 

    # print detailed information during export
    verbose=True, 

    # names for the input and output nodes, convenient for later inspection and manipulation
    input_names=["image"], 
    output_names=["output"], 

    # ONNX operator-set version to target when exporting
    opset_version=11, 

    # mark the batch, height and width dimensions as dynamic;
    # in the exported ONNX model they show up as symbolic (-1) dimensions.
    # Usually only the batch dimension should be dynamic; avoid making the others dynamic.
    dynamic_axes={
        "image": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 2: "height", 3: "width"},
    }
)

print("Done.!")

2. Visualizing with Netron

Opening the exported model in Netron shows that the network input node is named image and the output node is named output; these are the names specified during the ONNX export. Netron also shows that the batch, height, and width dimensions of these tensors are dynamic.
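
Netron can be launched from the command line or directly from Python; as a minimal sketch (assuming the netron package has been installed, e.g. via pip):

import netron

# start a local Netron server and open the exported model in the browser
netron.start("demo.onnx")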

3. Modifying the ONNX Model

1) Modifying the model input dimensions

(1) Converting dynamic dimensions to static dimensions

import onnx
import onnxruntime as rt
import os
import numpy as np
import argparse

# fix_dim_tools freezes the dynamic input dimensions of an ONNX model to user-supplied shapes,
# infers the resulting output shapes with onnxruntime, and rewrites the model's shape info accordingly.
class fix_dim_tools:
    def __init__(self, model_path, inputs_shape, inputs_dtype):
        assert os.path.exists(model_path), "{} not exists".format(model_path)
        if inputs_dtype is None:
            print('inputs_dtype is not defined, using float for all input nodes')
            inputs_dtype = ['float']*len(inputs_shape)
        else:
            assert len(inputs_shape)==len(inputs_dtype), "inputs shape list should have same length as inputs_dtype"
    
        model = onnx.load(model_path)
        self.model = model
        self.model_path = model_path
        self.inputs_shape = inputs_shape
        self.inputs_dtype = inputs_dtype

        self.inputs_shape_dict = {}
        self.inputs_type_dict = {}
        self.outputs_shape_dict = {}
    
    def check_dynamic_input(self):
        # check dynamic input and get real input shape
        inputs_number = len(self.model.graph.input)
        assert inputs_number==len(self.inputs_shape),"model has {} inputs, but {} inputs_shape was given, not match".format(inputs_number,len(self.inputs_shape))
        state = False

        for i in range(inputs_number):
            _input = self.model.graph.input[i]
            dim_values = [dim.dim_value for dim in _input.type.tensor_type.shape.dim]
            if 0 in dim_values:
                state = True
                print('Input node:{} is dynamic input, the shape info is {}. Using given shape-{} instead.'.format(_input.name, dim_values, self.inputs_shape[i]))
                self.inputs_shape_dict[_input.name] = self.inputs_shape[i]
            else:
                print('Input node:{} is normal input, the shape info is {}. Ignore given shape-{}'.format(_input.name, dim_values, self.inputs_shape[i]))
                self.inputs_shape_dict[_input.name] = dim_values
            self.inputs_type_dict[_input.name] = self.inputs_dtype[i]
        return state

    def run_onnxruntime_to_get_output_shape(self):
        sess = rt.InferenceSession(self.model_path)
        inputs_dict = {}

        # generate fake input
        for key in self.inputs_shape_dict.keys():
            if self.inputs_type_dict[key] == 'int':
                inputs_dict[key] = np.random.randint(0,255,self.inputs_shape_dict[key]).astype(np.uint8)
            elif self.inputs_type_dict[key] == 'float':
                inputs_dict[key] = np.random.randn(*self.inputs_shape_dict[key]).astype(np.float32)

        outputs_name = [node.name for node in sess.get_outputs()]
        res = sess.run(outputs_name, inputs_dict)
        for i in range(len(sess.get_outputs())):
            node_name = sess.get_outputs()[i].name
            self.outputs_shape_dict[node_name] = list(res[i].shape)
            print("After inference, output:{} shape is [{}]".format(node_name, list(res[i].shape)))

    def run(self):
        dynamic_input = self.check_dynamic_input()

        if dynamic_input is False:
            print("{} hasn't dynamic input, no file will be generated.".format(self.model_path))
            return False
        else:
            self.run_onnxruntime_to_get_output_shape()
    
        try:
            from onnx.tools import update_model_dims
            self.model = update_model_dims.update_inputs_outputs_dims(self.model, self.inputs_shape_dict, self.outputs_shape_dict)
            return True
        except Exception as e:
            print("Automatically fix failed. Please try to export an non-dynamic onnx model first. Exception: {}".format(repr(e)))
            return False
    
    def export(self,output_path):
        onnx.save(self.model, output_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="This script fixes the dims for an ONNX model that has dynamic inputs.\n"
                                                 "Usage example: python freeze_inshape_for_onnx_model.py "
                                                 "--onnx_file_path model_in.onnx "
                                                 "--output_path model_out.onnx "
                                                 "--inputs_shape 1,3,112,112#1,3,64,64 "
                                                 "--inputs_dtype float#int")
    parser.add_argument('--onnx_file_path', type=str, default='./models/pplcnet.onnx', help='input model path')
    parser.add_argument('--output_path', type=str, default='./models/pplcnet_freeze.onnx', help='output model path')
    parser.add_argument('--inputs_shape', type=str, default='1,3,224,224', help='inputs shape list, write as 1,3,112,112#1,3,64,64 for multi-input.')
    parser.add_argument('--inputs_dtype', type=str, required=False, help='(Optional) inputs dtype list, write as float#int for multi-input. Default is float')

    args = parser.parse_args()
    inputs_shape = []
    inputs_shape_define = args.inputs_shape.split('#')
    for shape_str in inputs_shape_define:
        inputs_shape.append([int(dim) for dim in shape_str.split(',')])

    if args.inputs_dtype is None:
        inputs_dtype = None
    else:
        inputs_dtype = args.inputs_dtype.split('#')

    fdt = fix_dim_tools(args.onnx_file_path, inputs_shape, inputs_dtype)
    state = fdt.run()
    if state is True:
        fdt.export(args.output_path)
        print('Success! Fix dim done. Export new model to path:{}'.format(args.output_path))
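
To confirm that the frozen model no longer has dynamic input dimensions, a quick check like the one below can be run on the exported file (model_out.onnx stands for whatever --output_path was used):

import onnx

m = onnx.load("model_out.onnx")
for inp in m.graph.input:
    dims = [d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)  # every entry should now be a positive integer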

(2) Modifying the batch size

import onnx
import onnx_graphsurgeon as gs

if __name__ == '__main__':
    
    model_path = "./models/onnx/pplcnet_1b.onnx"
    output_model_path = "./models/onnx/pplcnet_4b.onnx"
    new_batch_size = 4  # the new batch size
     
    model = onnx.load(model_path)
    
    graph = gs.import_onnx(model)
    
    input_node = graph.inputs[0]
    input_shape = input_node.shape
    input_shape[0] = new_batch_size
    input_node.shape = input_shape
    
    output_node = graph.outputs[0]
    output_shape = output_node.shape
    output_shape[0] = new_batch_size
    output_node.shape = output_shape
    
    onnx.save(gs.export_onnx(graph), output_model_path)
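
Note that this only rewrites the declared input and output shapes; if the graph contains operators with hard-coded shapes (for example a Reshape whose target shape fixes the batch size), those have to be adjusted separately. A minimal sanity check of the new model, assuming onnxruntime is installed and all remaining dimensions of the pplcnet input are static:

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("./models/onnx/pplcnet_4b.onnx")
inp = sess.get_inputs()[0]

# build a random input matching the new declared shape (batch size 4)
x = np.random.randn(*inp.shape).astype(np.float32)
out = sess.run(None, {inp.name: x})[0]
print(out.shape)  # the first dimension should now be 4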
