3、TensorRT学习笔记之ONNX转engine

        摘要:主要学习记录了ONNX转TensorRT流程、代码。末尾有完整代码。

目录:

        3.1 创建TensorRT的日志记录器

        3.2 创建bulider对象 

        3.3 设置engine参数

        3.4 定义network并加载ONNX解析器

        3.5 获取网络的输入输出

        3.6 动态输入

        3.7 检查设备是否支持FP16(半精度)推理

        3.8 写入engine,并序列化model

        3.9 完整代码

        3.10 遇见的问题


3.1 创建TensorRT的日志记录器

log = trt.Logger()

3.2 创建bulider对象 

        使用日志记录器创建 TensorRT Builder 对象,并通过Builder创建network并从该网络生成engine

        其中:trt.OnnxParser(network, log)需要传入两个参数。一个是已创建network,一个是日志记录器

builder = trt.Builder(log)                # 使用日志记录器创建 TensorRT Builder 对象
parser = trt.OnnxParser(network, log)     # 从network生成engine

3.3 设置engine参数

# 创建 Builder Config 对象
config = builder.create_builder_config()            
# 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节。指定最大可用显存
config.max_workspace_size = workspace * 1 << 30     

3.4 定义network并加载ONNX解析器

         通过builder创建一个空网络,什么都没有,需要将ONNX的模型结构信息写入创建的空network。

flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network()        # 通过Builder创建network,此时的network还是一个空网络
parser = trt.OnnxParser(network, log)     

         查看ONNX是否解析成功并将ONNX中的模型结构等信息写入network

# 查看是否解析成功,同时将模型结构写进了network
if not parser.parse_from_file(str(onnx)):       
    raise RuntimeError(f'failed to load ONNX file: {onnx}')

3.5 获取网络的输入输出

# 可能不是num_inputs,根据实际情况来。
inputs = [network.get_input(i) for i in range(network.num_inputs)]            
outputs = [network.get_output(i) for i in range(network.num_outputs)]

3.6 动态输入

if dynamic:
    im = torch.zeros(1, 3, *imgsz).to(device)    # 我这儿输入是im = torch.zeros(1,3,640,640)
    if im.shape[0] <= 1:
        # log.warning(f"{trt} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
        print('x')
    profile = builder.create_optimization_profile()
    for inp in inputs:
        profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
    config.add_optimization_profile(profile)

3.7 检查设备是否支持FP16(半精度)推理

其中:

  • builder.platform_has_fast_fp16:用于检查当前设备是否可以进行半精度计算。
  • half:自定义bool参数,用于决定是否半径都推理
  • config.set_flag(trt.BuilderFlag.FP16):set_flag方法来设置config对象的标志,将FP16标志添加到flags中
if builder.platform_has_fast_fp16 and half:
    config.set_flag(trt.BuilderFlag.FP16)

3.8 写入engine,并序列化model

with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
    t.write(engine.serialize())

        如果希望在trt模型中加入classes(其余信息类似)。

with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
    classes = (['person', 'car'])
    add_meta_to_model(t, classes, type='trt')
    t.write(engine.serialize())

3.9 完整代码

import numpy as np
import tensorrt as trt
import torch
import logging

# logger to capture errors, warnings, and other information during the build and inference phases
TRT_LOGGER = trt.Logger()


def build_engine(onnx, dynamic=True, half=True):
    # f = onnx.with_suffix('.engine')
    f = 'trt.engine'
    # 1、创建日志记录器
    log = trt.Logger()
    # 2、创建builder对象
    builder = trt.Builder(log)
    # 3、创建 Builder Config 对象
    config = builder.create_builder_config()
    # 4、将workspace*1 二进制左移30位后的10进制
    workspace = 1
    config.max_workspace_size = workspace * 1 << 30         # 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节
    # 5、定义networko并加载ONNX解析器
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, log)

    if not parser.parse_from_file(str(onnx)):       # 查看是否解析成功
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    # 6、获得网络的输入输出
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]

    # 7.判断是否动态输入
    if dynamic:
        im = torch.zeros(1,3,640,640)
        if im.shape[0] <= 1:
            # log.warning(f"{trt} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
            print('x')
        profile = builder.create_optimization_profile()
        for inp in inputs:
            profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
        config.add_optimization_profile(profile)
    # 判断是否支持FP16推理

    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    # build engine 文件的写入  这里的f是前面定义的engine文件
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        # 序列化model
        t.write(engine.serialize())
    return f, None

if __name__ == '__main__':
    engine, context = build_engine(r'D:\zy\Yolo\yolov8-ZY\yolov8n.onnx')

3.10 遇见的问题

1、AttributeError: 'tensorrt.tensorrt.Builder' object has no attribute 'max_workspace_size'

原因是:tensorrt8.0以上删除了max_workspace_size属性。

  • 降低tensorRT版本到7.x版本
  • 或者如下
config = builder.create_builder_config()            # 创建 Builder Config 对象
config.max_workspace_size = workspace * 1 << 30     # 设置 TensorRT 推理引擎使用的最大工作空间大小,单位为字节

上一篇:2、TensorRT学习笔记之PT转ONNX、可视化ONNX

下一篇:正在学习、持续更新(实战,瑞芯微RK3588部署yolov8检测模型)

参考文章:利用python版tensorRT导出engine【以yolov5为例】_yolov5 export得到的engine和tensorrt的engine-CSDN博客

  • 38
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值