将 ONNX 模型转换为 TensorRT(trt)引擎时,如果输入带有动态 shape,一般需要为每个输入指定 3 个 shape,分别是最小、最优与最大。但是我们在测试时不想写那么多的代码,能否自动完成这 3 个 shape 的配置?这里实现了一版,可以为动态维度自动填入默认值。
def _trt_major_version():
    """Return the major version of the imported TensorRT package as an int."""
    # Compare the whole leading number: looking only at the first character
    # would classify "10.x" as older than "7".
    return int(trt.__version__.split('.')[0])


def _auto_dynamic_shapes(onnx_file_path):
    """Build default [min, opt, max] shapes for every dynamic ONNX graph input.

    An input counts as dynamic when at least one of its dimensions has a
    symbolic name (``dim_param``) or a non-positive ``dim_value``. Each such
    dimension is replaced by 2 (min), 256 (opt) and 512 (max); fixed
    dimensions are kept unchanged.

    Args:
        onnx_file_path: Path of the ONNX model to inspect.

    Returns:
        Dict mapping input name to ``[min_shape, opt_shape, max_shape]``,
        containing only the inputs that have a dynamic dimension.
    """
    auto_shapes = {}
    mod = onnx.load(onnx_file_path)
    for inp in mod.graph.input:
        dims = inp.type.tensor_type.shape.dim
        if not any(d.dim_param or d.dim_value <= 0 for d in dims):
            continue
        shape = [d.dim_value for d in dims]
        auto_shapes[inp.name] = [
            [(i if i > 0 else fill) for i in shape]
            for fill in (2, 256, 512)
        ]
    return auto_shapes


def build_engine(onnx_file_path, engine_dest_path, trt_engine_datatype=trt.DataType.HALF, batch_size=1, silent=False, dynamic_shapes=None, values_of_static_shape=None, max_mem=(1 << 30)):
    """Parse an ONNX file, build a TensorRT engine and write it to disk.

    Inputs with dynamic dimensions automatically get an optimization profile
    (see ``_auto_dynamic_shapes``); entries in ``dynamic_shapes`` override the
    automatic ones for the same input name.

    Args:
        onnx_file_path: Path of the ONNX model to convert.
        engine_dest_path: Path the serialized engine is written to.
        trt_engine_datatype: When equal to ``trt.DataType.HALF`` the FP16
            builder flag is set; any other value leaves the flags untouched.
        batch_size: Assigned to ``builder.max_batch_size`` when the builder
            still exposes that attribute.
        silent: When true, suppresses the "Building an engine" message. The
            other progress messages are always printed.
        dynamic_shapes: Optional dict ``name -> (min_shape, opt_shape,
            max_shape)`` passed to ``profile.set_shape``.
        values_of_static_shape: Optional dict ``name -> (min_value, opt_value,
            max_value)`` passed to ``profile.set_shape_input``
            (presumably for shape-tensor inputs; confirm against the model).
        max_mem: Workspace memory limit in bytes.

    Returns:
        None. Parse and build failures are reported on stdout.

    Raises:
        FileNotFoundError: If ``onnx_file_path`` does not exist.
    """
    dynamic_shapes = {} if dynamic_shapes is None else dynamic_shapes
    values_of_static_shape = {} if values_of_static_shape is None else values_of_static_shape

    # Fail loudly (and before touching the GPU) instead of exiting with a
    # success status when the model is missing.
    if not os.path.exists(onnx_file_path):
        raise FileNotFoundError('ONNX file {} not found.'.format(onnx_file_path))

    trt_logger = trt.Logger(trt.Logger.VERBOSE)
    # Networks are created with the explicit-batch flag from TensorRT 7 on.
    network_flags = [] if _trt_major_version() < 7 else [1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)]
    with trt.Builder(trt_logger) as builder, builder.create_network(*network_flags) as network, trt.OnnxParser(network, trt_logger) as parser:
        # max_batch_size is not available on every TensorRT release.
        if hasattr(builder, 'max_batch_size'):
            builder.max_batch_size = batch_size

        config: trt.IBuilderConfig = builder.create_builder_config()
        # Prefer the memory-pool API where it exists; older releases only
        # offer the max_workspace_size attribute.
        if hasattr(config, 'set_memory_pool_limit'):
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_mem)
        else:
            config.max_workspace_size = max_mem
        if trt_engine_datatype == trt.DataType.HALF:  # float 16
            config.set_flag(trt.BuilderFlag.FP16)

        # Parse model file
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            if not parser.parse(model.read()):
                print('ERROR: Failed to parse the ONNX file.')
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                return None
        print('Completed parsing of ONNX file')
        if not silent:
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))

        # Automatic profiles first, then let the caller's entries win.
        dynamic_shapes_fin = _auto_dynamic_shapes(onnx_file_path)
        dynamic_shapes_fin.update(dynamic_shapes)

        # Only register a profile when there is something to put in it.
        if dynamic_shapes_fin or values_of_static_shape:
            profile = builder.create_optimization_profile()
            for binding_name, (min_shape, opt_shape, max_shape) in dynamic_shapes_fin.items():
                profile.set_shape(binding_name, min_shape, opt_shape, max_shape)
            for ten_name, (min_value, opt_value, max_value) in values_of_static_shape.items():
                profile.set_shape_input(ten_name, min_value, opt_value, max_value)
            config.add_optimization_profile(profile)

        # Build straight to a serialized plan where supported; otherwise
        # build an engine object and serialize it.
        if hasattr(builder, 'build_serialized_network'):
            buf = builder.build_serialized_network(network, config)
        else:
            trt_engine = builder.build_engine(network, config)
            buf = None if trt_engine is None else trt_engine.serialize()
        # The builder signals failure by returning None rather than raising.
        if buf is None:
            print('ERROR: Failed to build the TensorRT engine.')
            return None

        with open(engine_dest_path, 'wb') as f:
            f.write(buf)
    return None
用法:可以手动指定各输入的 shape;也可以不指定,此时动态维度默认使用 2、256、512 分别作为最小、最优、最大值,用于测试验证。
# Example 1: hand-written optimization profiles for the t2s encoder.
# Every profile is [min_shape, opt_shape, max_shape]; only the last axis
# changes between the three entries (1, 256, 512).
build_engine(
    onnx_file_path=f"onnx/{project_name}/{project_name}_t2s_encoder.onnx",
    engine_dest_path=f"onnx/{project_name}/{project_name}_t2s_encoder.trt",
    dynamic_shapes={
        "ref_seq": [(1, n) for n in (1, 256, 512)],
        "text_seq": [(1, n) for n in (1, 256, 512)],
        "ref_bert": [(1024, n) for n in (1, 256, 512)],
        "text_bert": [(1024, n) for n in (1, 256, 512)],
        "ssl_content": [(1, 768, n) for n in (1, 256, 512)],
    },
    # Value ranges (min, opt, max) handed to the profile's set_shape_input.
    values_of_static_shape={
        "iy_len": [(165,), (265,), (1165,)],
        "ikv_len": [(227,), (327,), (1227,)],
    },
)

# Example 2: no profiles given, so build_engine derives them on its own.
build_engine(
    f"onnx/{project_name}/{project_name}_t2s_fsdec.onnx",
    f"onnx/{project_name}/{project_name}_t2s_fsdec.trt",
)