Deploying Models from Different Deep Learning Frameworks with TensorRT

This article describes how to convert Keras, TensorFlow, and PyTorch models to the unified ONNX format, and explains in detail how to build a TensorRT engine with dynamic-shape support to fit different deployment requirements.

Model Deployment
With the development of deep learning, a number of frameworks have emerged, mainly TensorFlow, Keras, and PyTorch. For deployment, models trained in these different frameworks are generally converted to a unified intermediate representation, the ONNX format.
1. Converting models from different frameworks to ONNX
1) Keras to ONNX requires the keras2onnx library, which currently supports only TensorFlow 1.x / 2.0-2.2, so keep this in mind when training the model:
import onnx
import keras2onnx
from tensorflow.keras.models import load_model

model = load_model(model_path)
onnx_model = keras2onnx.convert_keras(model, model.name)
output_model_file = '*.onnx'
onnx.save_model(onnx_model, output_model_file)
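It is worth verifying the exported file before moving on. A minimal sketch, reusing output_model_file from the snippet above, that loads the model and runs the ONNX structural checker:

import onnx

# Load the exported model and verify that the graph is structurally valid ONNX
exported_model = onnx.load(output_model_file)
onnx.checker.check_model(exported_model)
print(onnx.helper.printable_graph(exported_model.graph))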
2) TensorFlow to ONNX requires the tf2onnx library, which currently supports tf-1.x, tf-2.x, Keras, tensorflow.js, and tflite:
import onnx
import tf2onnx

loaded_keras_model = load_model(model_path)
onnx_model, _ = tf2onnx.convert.from_keras(loaded_keras_model)
onnx.save(onnx_model, '*.onnx')
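tf2onnx.convert.from_keras also accepts an explicit input signature and opset, which is useful for pinning the input tensor name and making the batch dimension dynamic before the TensorRT step below. A sketch; the (None, 224, 224, 3) shape and the name "images" are illustrative assumptions:

import tensorflow as tf
import tf2onnx
import onnx

# Pin the input name, make the batch dimension dynamic, and fix the opset
spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="images"),)
onnx_model, _ = tf2onnx.convert.from_keras(loaded_keras_model, input_signature=spec, opset=13)
onnx.save(onnx_model, '*.onnx')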
3) PyTorch to ONNX uses the export function built into PyTorch:
import torch

f = model_path.replace('.pth', '.onnx')
input_names = ["images"]
output_names = ["output0", "output1"]
if dynamic:
    # Mark batch, height and width of the input, and the batch of output0, as dynamic axes
    dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}
    dynamic["output0"] = {0: "batch"}
torch.onnx.export(
    model.cpu(),
    im.cpu() if dynamic else im,
    f,
    verbose=True,
    opset_version=13,
    do_constant_folding=True,  # note: some backends require do_constant_folding=False
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic or None,
)
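After exporting, it is worth checking that the ONNX model reproduces the PyTorch outputs. Below is a minimal sketch using onnxruntime; it reuses model, im, and f from the snippet above, and assumes the model's first output is a plain tensor (adjust the indexing for your own model):

import numpy as np
import onnxruntime as ort

# Run the original PyTorch model and the exported ONNX model on the same input
with torch.no_grad():
    torch_out = model.cpu()(im.cpu())

session = ort.InferenceSession(f, providers=["CPUExecutionProvider"])
ort_out = session.run(None, {"images": im.cpu().numpy()})

# Compare the first output within a small tolerance
np.testing.assert_allclose(torch_out[0].detach().numpy(), ort_out[0], rtol=1e-3, atol=1e-5)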
2. Converting ONNX to a TensorRT engine
import os
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# onnx_file_path and engine_file_path are expected to be defined at module level

def build_engine(max_batch_size, save_engine, dynamic_shapes, dynamic_batch_size=1):
    """Takes an ONNX file and creates a TensorRT engine to run inference with."""
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(EXPLICIT_BATCH) as network, \
            trt.OnnxParser(network, TRT_LOGGER) as parser:
        config = builder.create_builder_config()
        config.max_workspace_size = (1 << 30) * 2  # workspace size, 2 GB
        builder.max_batch_size = max_batch_size
        # Parse the ONNX model file
        if not os.path.exists(onnx_file_path):
            quit('ONNX file {} not found'.format(onnx_file_path))
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            if not parser.parse(model.read()):
                for error in range(parser.num_errors):
                    print(parser.get_error(error))
                print("=====Parsing fail!!!=====")
            else:
                print('Completed parsing of ONNX file')
        print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
        if len(dynamic_shapes) > 0:
            # Dynamic input: register an optimization profile with (min, opt, max) shapes
            builder.max_batch_size = dynamic_batch_size
            profile = builder.create_optimization_profile()
            for binding_name, dynamic_shape in dynamic_shapes.items():
                print(f"=> set dynamic shape {binding_name}: {dynamic_shape}")
                min_shape, opt_shape, max_shape = dynamic_shape
                profile.set_shape(binding_name, min_shape, opt_shape, max_shape)
            config.add_optimization_profile(profile)
        # Static shapes need no optimization profile
        engine = builder.build_engine(network, config)
        print("Completed creating Engine")
        if save_engine:
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
        return engine
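For reference, here is a sketch of how the function above might be called and how a dynamic-shape engine is used at inference time. The input name "images" matches the input_names from the PyTorch export; the (min, opt, max) shapes are illustrative assumptions, and onnx_file_path / engine_file_path must be defined at module level as the function expects:

# (min, opt, max) shapes for each dynamic input binding
dynamic_shapes = {
    "images": ((1, 3, 320, 320), (1, 3, 640, 640), (4, 3, 1280, 1280)),
}
engine = build_engine(max_batch_size=4, save_engine=True, dynamic_shapes=dynamic_shapes)

# With dynamic shapes, the actual input shape must be set on the execution
# context before running inference; otherwise the binding dimensions stay at -1.
context = engine.create_execution_context()
context.set_binding_shape(0, (1, 3, 640, 640))  # binding 0 is the "images" input
assert context.all_binding_shapes_specified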
