TensorRT(11):python版本序列化保存与加载模型


TensorRT系列传送门(不定期更新): 深度框架|TensorRT



楼主曾经在TensorRT(7):python版本使用入门一文中简要记录了python版本是序列化与反序列化加载模型的步骤,但因为环境以及TRT版本不同,API也有相当大的变化,这里重新记录下,在windows下,tensorrt8.2.3.0版本下,调用python的API是如何加载模型的。

实验案例:采用 yolov5的onnx模型,进行FP16量化保存模型。
代码案例均来自 TensorRT提供的sample中。
详细可见TensorRT-8.2.3.0\samples\python
在这里插入图片描述

一、序列化保存模型

与C++端序列化保存模型的步骤类似

  • 1、首先定义个log 文件,然后创建一个runtime
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(TRT_LOGGER)
  • 2、建立builder,设置maxBatchSize参数
builder = trt.Builder(TRT_LOGGER)  # 创建一个builder
builder.max_batch_size = 1
  • 3、配置config,如设置fp16等
config = builder.create_builder_config()  # 创建一个congig
config.max_workspace_size = 1 << 20
config.set_flag(trt.BuilderFlag.FP16)
  • 4、解析onnx文件,并通过config序列化生成一个network
network = builder.create_network(EXPLICIT_BATCH)  # 创建一个network
parser = trt.OnnxParser(network, TRT_LOGGER)

model = open(onnx_file_path, 'rb')
if not parser.parse(model.read()):
    for error in range(parser.num_errors):
        print(parser.get_error(error))

network.get_input(0).shape = [1, 3, 640, 640]
print('Completed parsing of ONNX file')
print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
plan = builder.build_serialized_network(network, config)
with open(engine_file_path, "wb") as f:
      f.write(plan)
      print("Completed write Engine")

二、反序列化加载模型

在一中序列化建立好network后,可以调用deserialize_cuda_engine反序列化生成一个 engine

engine = runtime.deserialize_cuda_engine(plan)
print("Completed creating Engine")

如果加载保存在本地的trt模型,可以直接加载engine

 if os.path.exists(engine_file_path):
      # If a serialized engine exists, use it instead of building an engine.
      print("Reading engine from file {}".format(engine_file_path))
      with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
          return runtime.deserialize_cuda_engine(f.read())

三、完整代码

完整代码都可在github上的官网samples查询。
onnx_to_tensorrt.py


def get_engine(onnx_file_path, engine_file_path=""):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(common.EXPLICIT_BATCH) as network, builder.create_builder_config() as config, trt.OnnxParser(network, TRT_LOGGER) as parser, trt.Runtime(TRT_LOGGER) as runtime:
            config.max_workspace_size = 1 << 28 # 256MiB
            builder.max_batch_size = 1
            # Parse model file
            if not os.path.exists(onnx_file_path):
                print('ONNX file {} not found, please run yolov3_to_onnx.py first to generate it.'.format(onnx_file_path))
                exit(0)
            print('Loading ONNX file from path {}...'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                if not parser.parse(model.read()):
                    print ('ERROR: Failed to parse the ONNX file.')
                    for error in range(parser.num_errors):
                        print (parser.get_error(error))
                    return None
            # The actual yolov3.onnx is generated with batch size 64. Reshape input to batch size 1
            network.get_input(0).shape = [1, 3, 608, 608]
            print('Completed parsing of ONNX file')
            print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
            plan = builder.build_serialized_network(network, config)
            engine = runtime.deserialize_cuda_engine(plan)
            print("Completed creating Engine")
            with open(engine_file_path, "wb") as f:
                f.write(plan)
            return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
  • 2
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
《ArcGIS Engine开发从入门到精通》讲解是基于ArcGIS Engine 9.3开发平台,介绍了相关的开发技术和工程应用,并用C#语言编程实现了工程实例。《ArcGIS Engine开发从入门到精通》共4篇分18章,第一篇基础篇(第1~9章)集中介绍了 ArcGIS Engine基础知识,包括开发基础组件对象模型、ArcGIS Engine介绍、基于.NET的ArcGIS Engine的开发,ArcGIS Engine中的控件、框架控件介绍、控件使用实例等,为以后应用ArcGIS Engine的各种接口,快速地实现系统的开发打下坚实的基础;第二篇应用提高篇(第10~12章)介绍了ArcGIS Engine的应用框架、空间分析、ArcGIS Server服务、三维模式数据编辑等高级应用,通过学习这些高级应用可以使读者得心应手地完成各种GIS系统的开发;第三篇综合实例篇(第13章~第14章)用两个综合例子将前面讲解的知识点串起来,让读者将学习的知识点融合起来,以便可以胜任项目开发的角色;第四篇常见疑难解答与经验技巧集萃(第15~18章),本篇将一些开发过程中常见的异常、数据库连接与释放、数据加载以及一些经验技巧做了介绍,本篇的例子主要是对开发过程中常碰到的问题和实战技巧进行了汇总解答,以便帮助读者提高工作效率。, 《ArcGIS Engine开发从入门到精通》从开发者的角度,全面讨论了ArcGIS Engine开发的知识,让读者了解和掌握ArcGIS Engine开发的实战技术,无论是想对ArcGIS Engine入门还是对ArcGIS Engine感兴趣的GIS人员,都能从《ArcGIS Engine开发从入门到精通》中得到提高。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值