ONNX 转 TensorRT 时自动配置动态 shape 默认值的脚本

将 ONNX 模型转换为 TensorRT 引擎时,每个带动态维度的输入一般需要指定 3 个 shape:最小(min)、最优(opt)和最大(max)。测试时我们不想为每个输入都手写这些配置,因此实现了一个能自动生成这 3 个 shape 的版本。

import os

import tensorrt as trt
import pycuda.driver as cuda
import onnx


def build_engine(onnx_file_path, engine_dest_path, trt_engine_datatype=trt.DataType.HALF, batch_size=1, silent=False, dynamic_shapes=None, max_mem=(1 << 30), dynamic_dim_defaults=(1, 256, 512)):
    """Parse an ONNX model, build a TensorRT engine and write it to disk.

    Every network input that has dynamic dimensions automatically gets an
    optimization-profile entry: each dynamic dimension is replaced by the
    values in *dynamic_dim_defaults* (min, opt, max), static dimensions are
    kept as-is. Entries passed in *dynamic_shapes* override the automatic ones.

    Args:
        onnx_file_path: path of the ONNX model to convert.
        engine_dest_path: file the serialized engine is written to.
        trt_engine_datatype: ``trt.DataType.HALF`` enables the FP16 builder
            flag; any other value leaves the builder's default precision.
        batch_size: max batch size; only applied on TensorRT versions whose
            builder still exposes ``max_batch_size``.
        silent: suppress informational messages (errors are still printed).
        dynamic_shapes: optional ``{input_name: [min_shape, opt_shape, max_shape]}``
            overrides.
        max_mem: builder workspace memory limit in bytes.
        dynamic_dim_defaults: ``(min, opt, max)`` substituted for every dynamic
            dimension that is not covered by *dynamic_shapes*.

    Returns:
        The serialized engine on success, or ``None`` if the ONNX file is
        missing, parsing fails, or the engine build fails.
    """
    def info(msg):
        # Informational output honours `silent`; error messages are always printed.
        if not silent:
            print(msg)

    # Validate the input before creating any TensorRT objects. Return None
    # (like the parse-failure path) instead of exit(0), which terminated the
    # caller and reported success to the shell.
    if not os.path.isfile(onnx_file_path):
        print('ERROR: ONNX file {} not found.'.format(onnx_file_path))
        return None

    # Compare the numeric major version: a first-character string compare
    # treats "10.x" as older than "7".
    trt_major = int(trt.__version__.split('.')[0])
    # Explicit-batch flag for TRT 7-9 (as before); in TRT >= 10 networks are
    # always explicit batch and the flag is deprecated.
    network_args = ([1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)]
                    if 7 <= trt_major < 10 else [])

    trt_logger = trt.Logger(trt.Logger.WARNING)
    with trt.Builder(trt_logger) as builder, builder.create_network(*network_args) as network, trt.OnnxParser(network, trt_logger) as parser:
        # max_batch_size only matters for implicit-batch networks and no
        # longer exists in TensorRT 10.
        if hasattr(builder, 'max_batch_size'):
            builder.max_batch_size = batch_size
        config = builder.create_builder_config()
        # Workspace limit: max_workspace_size is deprecated since 8.4 and
        # removed in 10; prefer the memory-pool API when available.
        if hasattr(config, 'set_memory_pool_limit'):
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_mem)
        else:
            config.max_workspace_size = max_mem
        if trt_engine_datatype == trt.DataType.HALF:  # float16
            config.set_flag(trt.BuilderFlag.FP16)

        # Parse the model file.
        info('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            info('Beginning ONNX file parsing')
            if not parser.parse(model.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        info('Completed parsing of ONNX file')
        info('Building an engine from file {}; this may take a while...'.format(onnx_file_path))

        # Default profile shapes come from the parsed network itself. Its input
        # names are exactly what profile.set_shape expects, and this avoids
        # re-loading the whole ONNX model. User overrides win.
        profile_shapes = _auto_profile_shapes(network, dynamic_dim_defaults)
        profile_shapes.update(dynamic_shapes or {})
        if profile_shapes:
            info('===> using dynamic shapes!')
            profile = builder.create_optimization_profile()
            for binding_name, (min_shape, opt_shape, max_shape) in profile_shapes.items():
                profile.set_shape(binding_name, min_shape, opt_shape, max_shape)
            config.add_optimization_profile(profile)

        # build_serialized_network exists since TRT 8 (build_engine is gone in
        # 10); both paths may yield None on failure, which is checked below.
        if hasattr(builder, 'build_serialized_network'):
            plan = builder.build_serialized_network(network, config)
        else:
            engine = builder.build_engine(network, config)
            plan = engine.serialize() if engine is not None else None
        if plan is None:
            print('ERROR: Failed to build the TensorRT engine.')
            return None

        with open(engine_dest_path, 'wb') as f:
            f.write(plan)
        return plan


def _auto_profile_shapes(network, dynamic_dim_defaults=(1, 256, 512)):
    """Return default ``[min, opt, max]`` shapes for the dynamic inputs of *network*.

    Each dynamic dimension (reported by TensorRT as a negative value, -1) is
    replaced by the matching entry of *dynamic_dim_defaults*; static
    dimensions are copied into all three shapes. Fully static inputs are
    omitted from the result.
    """
    dim_min, dim_opt, dim_max = dynamic_dim_defaults  # ValueError if not exactly 3 values
    profile_shapes = {}
    for index in range(network.num_inputs):
        tensor = network.get_input(index)
        dims = list(tensor.shape)
        if not any(d < 0 for d in dims):
            continue  # static input: no profile entry needed
        profile_shapes[tensor.name] = [
            [fill if d < 0 else d for d in dims]
            for fill in (dim_min, dim_opt, dim_max)
        ]
    return profile_shapes

用法:可以通过 dynamic_shapes 参数手动指定各输入的 [min, opt, max] shape,也可以不指定;不指定时,所有动态维度默认分别取 1、256、512 作为最小、最优、最大值,便于测试验证。

# Example 1: the encoder, with explicit [min_shape, opt_shape, max_shape]
# given for each of its dynamic inputs.
build_engine(
    onnx_file_path=f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
    engine_dest_path=f"onnx/{project_name}/{project_name}_t2s_encoder.trt",
    dynamic_shapes={
        # input name: [min_shape, opt_shape, max_shape]
        "ref_seq": [(1, 1), (1, 256), (1, 512)],
        "text_seq": [(1, 1), (1, 256), (1, 512)],
        "ref_bert": [(1024, 1), (1024, 256), (1024, 512)],
        "text_bert": [(1024, 1), (1024, 256), (1024, 512)],
        "ssl_content": [(1, 768, 1), (1, 768, 256), (1, 768, 512)],
    },
)
# Example 2: no overrides, so every dynamic dimension uses the automatic
# defaults 1 / 256 / 512 (min / opt / max).
build_engine(
    onnx_file_path=f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
    engine_dest_path=f"onnx/{project_name}/{project_name}_t2s_fsdec.trt",
)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值