Pytorch\Onnx\TensorRT

本文介绍了如何将PyTorch模型转换为ONNX格式,并进一步利用TensorRT优化生成engine,以实现深度学习推理的加速。讨论了trtexec命令行工具和Python API的使用,强调了TensorRT对ONNX算子的支持限制,并提到了用于debug的工具和额外的转换方法。
摘要由CSDN通过智能技术生成

前言

TensorRT是NVIDIA推出的一款高效深度学习模型推理框架,其包括了深度学习推理优化器和运行时,能够让深度学习推理应用拥有低时延和高吞吐的优点。

在这里插入图片描述

本质上来讲,就是通过采用对模型中的部分算子进行融合、对特定尺寸的算子选用更好的实现方法,以及使用混合精度等方式,最终加速整个网络的推理速度。

在使用PyTorch训练得到网络模型后,我们希望在模型部署时通过TensorRT加速模型推理,那么可以先将PyTorch模型转为ONNX,然后再讲ONNX转为TensorRT的engine。

实现步骤

PyTorch模型转为ONNX

具体过程可参考 PyTorch模型转ONNX格式_TracelessLe的专栏-CSDN博客

ONNX转TensorRT的engine

  • trtexec 命令
  • https://github.com/NVIDIA/TensorRT/blob/master/samples/trtexec/README.md
trtexec --onnx=net_bs8_v1_simple.onnx --tacticSources=-cublasLt,+cublas --workspace=2048 --fp16 --saveEngine=net_bs8_v1.engine  --verbose

--onnx指定ONNX文件路径
--tacticSources指定使用的方法库
--workspace指定工作空间大小,单位是MB
--fp16 开启FP16模式
--saveEngine 指定生成的engine的保存路径
--verbose 打开verbose模式,更多打印信息。

方法二:基于Python API的engine生成

__author__ = 'TracelessLe'

import os
import tensorrt as trt

ONNX_SIM_MODEL_PATH = 'net_bs8_v1_simple.onnx'
TENSORRT_ENGINE_PATH_PY = 'net_bs8_v1_fp16_py.engine'

def build_engine(onnx_file_path, engine_file_path, flop=16):
    trt_logger = trt.Logger(trt.Logger.VERBOSE)  # trt.Logger.ERROR
    builder = trt.Builder(trt_logger)
    network = builder.create_network(
        1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    
    parser = trt.OnnxParser(network, trt_logger)
    # parse ONNX
    with open(onnx_file_path, 'rb') as model:
        if not parser.parse(model.read()):
            print('ERROR: Failed to parse the ONNX file.')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None
    print("Completed parsing ONNX file")
    builder.max_workspace_size = 2 << 30
    # default = 1 for fixed batch size
    builder.max_batch_size = 1
    # set mixed flop computation for the best performance
    if builder.platform_has_fast_fp16 and flop == 16:
        builder.fp16_mode = True

    if os.path.isfile(engine_file_path):
        try:
            os.remove(engine_file_path)
        except Exception:
            print("Cannot remove existing file: ",
                engine_file_path)

    print("Creating Tensorrt Engine")

    config = builder.create_builder_config()
    config.set_tactic_sources(1 << int(trt.TacticSource.CUBLAS))
    config
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值