Model Deployment, Part 3: Supporting Multi-Dimensional Dynamic Batch Settings

Model conversion supports marking multiple dimensions of a tensor as dynamic. In torch.onnx.export this is controlled by the dynamic_axes argument; the script below exports a static model and two dynamic variants (one with dim 0 dynamic, one with dims 2 and 3 dynamic), then runs each exported model against inputs of different shapes.
import torch
import onnxruntime
import numpy as np

# Model conversion with multi-dimensional dynamic batch settings
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        x = self.conv(x)
        return x

def test_dynamic_axes():
    model = Model()
    dummy_input = torch.rand(1, 3, 10, 10)
    model_names = ['model_static.onnx',
                   'model_dynamic_0.onnx',
                   'model_dynamic_23.onnx']
    # Dim 0 (batch) is dynamic
    dynamic_axes_0 = {
        'in': {0: 'batch'},
        'out': {0: 'batch'}
    }
    # Dims 2 and 3 (spatial dims) are dynamic
    dynamic_axes_23 = {
        'in': {2: 'batch', 3: 'batch'},
        'out': {2: 'batch', 3: 'batch'}
    }

    torch.onnx.export(model,
                      dummy_input,
                      model_names[0],
                      input_names=['in'],
                      output_names=['out'])

    torch.onnx.export(model,
                      dummy_input,
                      model_names[1],
                      input_names=['in'],
                      output_names=['out'],
                      dynamic_axes=dynamic_axes_0)

    torch.onnx.export(model,
                      dummy_input,
                      model_names[2],
                      input_names=['in'],
                      output_names=['out'],
                      dynamic_axes=dynamic_axes_23)  # specifies which dims of the input/output tensors are dynamic


def test_dynamic_and_static_model_export():
    model = Model()
    # Same shape as the export-time dummy input
    origin_tensor = np.random.rand(1, 3, 10, 10).astype(np.float32)
    # Different batch size (dim 0)
    mult_batch_tensor = np.random.rand(2, 3, 10, 10).astype(np.float32)
    # Different spatial size (dims 2 and 3)
    big_tensor = np.random.rand(1, 3, 20, 20).astype(np.float32)

    model_names = ['model_static.onnx',
                   'model_dynamic_0.onnx',
                   'model_dynamic_23.onnx']
    inputs = [origin_tensor, mult_batch_tensor, big_tensor]
    exceptions = dict()

    for model_name in model_names:
        for i, input_tensor in enumerate(inputs):
            try:
                # CPUExecutionProvider is passed explicitly so this also works on
                # onnxruntime versions that require providers to be specified
                ort_session = onnxruntime.InferenceSession(
                    model_name, providers=['CPUExecutionProvider'])
                ort_inputs = {'in': input_tensor}
                ort_session.run(['out'], ort_inputs)
            except Exception as e:
                exceptions[(i, model_name)] = e
                print(f'Input[{i}] on model {model_name} error.')
            else:
                print(f'Input[{i}] on model {model_name} succeed.')

if __name__ == '__main__':
    test_dynamic_axes()  # export the three ONNX models first
    test_dynamic_and_static_model_export()
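
As a quick sanity check, the exported graphs can also be inspected directly. The sketch below (an assumption: the onnx package is installed alongside onnxruntime) prints each model's input shape; a dynamic dimension appears as a symbolic name (dim_param) such as 'batch', while a static dimension appears as a fixed number (dim_value).

import onnx

for name in ['model_static.onnx', 'model_dynamic_0.onnx', 'model_dynamic_23.onnx']:
    graph_input = onnx.load(name).graph.input[0]
    dims = graph_input.type.tensor_type.shape.dim
    # A dynamic dim stores a symbolic dim_param; a static dim stores a numeric dim_value
    print(name, [d.dim_param or d.dim_value for d in dims])

With the three exports above, the static model is expected to accept only the original 1x3x10x10 input, model_dynamic_0.onnx additionally accepts the 2x3x10x10 input (dynamic dim 0), and model_dynamic_23.onnx additionally accepts the 1x3x20x20 input (dynamic dims 2 and 3); every other combination fails inside onnxruntime with a shape-mismatch error.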
