onnx设置动态batch/修改onnx的batch

  • 一. 训练结束后(例如pytorch)导出onnx模型时, 设置动态batch
  • def export_onnx(model, input_hwc, output_file, input_names=['input'],
                output_names=['output'], show=False, opset_version=9, dynamic=True):
        if dynamic:
            dynamic_axes = {
                'input': {
                    0: 'batch',
                },
                'output': {
                    0: 'batch'
                }
            }
        else:
            dynamic_axes = {}
    
        h, w, c = input_hwc
        input_shape = [1, c, h, w]
        one_img = torch.randn(input_shape)
    
        # register_extra_symbolics(opset_version)
        with torch.no_grad():
            torch.onnx.export(
                model.cpu().eval(),
                one_img,
                output_file,
                input_names=input_names,
                output_names=output_names,
                export_params=True,
                keep_initializers_as_inputs=True,
                verbose=show,
                dynamic_axes=dynamic_axes,
                opset_version=opset_version)
        print('>>> finish export onnx:', output_file)
    

    二.

  • 通过onnx库修改onnx模型的batch
  • # 安装onnx:pip install onnx
    import onnx
    def change_input_dim(model):
        # Use some symbolic name not used for any other dimension
        sym_batch_dim = "N"
        # or an actal value
        actual_batch_dim = "4" 
    
        # The following code changes the first dimension of every input to be batch-dim
        # Modify as appropriate ... note that this requires all inputs to
        # have the same batch_dim 
        inputs = model.graph.input
        for input in inputs:
            # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
            # Add checks as needed.
            dim1 = input.type.tensor_type.shape.dim[0]
            # update dim to be a symbolic value
            dim1.dim_param = sym_batch_dim
            # or update it to be an actual value:
            # dim1.dim_value = actual_batch_dim
    
    
    def apply(transform, infile, outfile):
        model = onnx.load(infile)
        transform(model)
        onnx.save(model, outfile)
    
    apply(change_input_dim, onnx_pth, save_pth)
    
    

三.动态推理。

# 安装onnxruntime: pip install onnxruntime
import onnxruntime as ort
import numpy as np
import torch 
x1 = torch.rand(1,3,112,112)
ort_sess1 = ort.InferenceSession(path_to_onnx_model)
outputs1 = ort_sess1.run(None, {'input.1': x2.numpy()})

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

josiechen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值