Deploying on the Jetson TX2 NX: converting a weights file to ONNX and then to TRT

1. Converting the weights file to ONNX

My model is built with PyTorch, but for convenience during training it is wrapped with pytorch-lightning, so it can be regarded as a pytorch-lightning adaptation. A few problems came up while exporting the ONNX file, so I am recording them here:

Code reference: https://blog.csdn.net/qq_37541097/article/details/114847600

import torch
import torch.onnx
import onnx
import onnxruntime
import numpy as np


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    # return tensor.detach().numpy() if tensor.requires_grad else tensor.numpy()


def main():
    weights_path = './ckpt'  # pytorch-lightning saves .ckpt files, plain pytorch uses .pth
    onnx_file_name = 'your_model.onnx'  # name of the ONNX file to write

    batch_size = 1
    img_channel = 3  # RGB image
    img_h = ...  # set to your model's input height
    img_w = ...  # set to your model's input width

    # input to the model
    # [batch_size, channel, height, width]
    network = ...  # instantiate your model architecture here
    # Note: a pytorch-lightning .ckpt nests the weights under the 'state_dict' key,
    # so you may need torch.load(weights_path)['state_dict'] here.
    network.load_state_dict(torch.load(weights_path), strict=False)
    # model = Network(network,               # this is the pytorch-lightning call
    #                 ckpt_path=weights_path)
    model = network.eval()
    x = torch.rand(batch_size, img_channel, img_h, img_w, requires_grad=True)
    torch_out = model(x)

    # export the model
    # using torch:
    torch.onnx.export(model,          # model being run
                      x,                # model input (or a tuple for multiple inputs)
                      onnx_file_name,   # where to save the model (can be a file or file-like object)
                      input_names=['input'],
                      output_names=['output'],
                      verbose=False)

    # model.to_onnx(onnx_file_name, x, export_params=True)  # in practice this gave large errors at the checking step below; the torch.onnx.export version is recommended

    # check onnx model
    onnx_model = onnx.load(onnx_file_name)
    onnx.checker.check_model(onnx_model, full_check=True)

    ort_session = onnxruntime.InferenceSession(onnx_file_name)

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)

    # compare the ONNX Runtime and PyTorch results
    # assert_allclose: Raises an AssertionError if two objects are not equal up to desired tolerance.
    np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-02, atol=1e-05)

    print("Exported model has been tested with ONNXRuntime, and the result looks good!")

I recommend the plain PyTorch export path. pytorch-lightning's exporter does not seem very stable, and it appears to see relatively little use in China at the moment: the blog posts that document it are quite shallow, and I have not yet come across deeper material on lightning's internals.
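If you only have a Lightning .ckpt, one workaround is to load its state_dict manually and strip the LightningModule attribute prefix before feeding it to the bare network. A minimal sketch, assuming the LightningModule stored the network under an attribute named model (adjust the prefix to your own code):

import torch

ckpt = torch.load('./ckpt', map_location='cpu')
state_dict = ckpt['state_dict']  # Lightning nests the weights under this key
# Drop the 'model.' attribute prefix that Lightning adds to every key
# (the prefix is an assumption; check your own checkpoint's keys).
state_dict = {k.replace('model.', '', 1): v for k, v in state_dict.items()}
network.load_state_dict(state_dict, strict=True)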

2. Converting the ONNX file to a TRT file

Reference code: "【python】tensorrt8版本下的onnx转tensorrt engine" (CSDN blog)

import tensorrt as trt  # tensorrt installed in JetsonTX2 NX
import os

# TRT 8 requires networks to be created with an explicit batch dimension
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
TRT_LOGGER = trt.Logger()


""" 这里的方法是因为使用tensorrt的trtexec.exe文件转换.onnx文件不成功时"""

def get_engine(onnx_file_path, engine_file_path=''):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""

    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
            EXPLICIT_BATCH
        ) as network, builder.create_builder_config() as config, trt.OnnxParser(
            network, TRT_LOGGER
        ) as parser, trt.Runtime(
            TRT_LOGGER
        ) as runtime:
            config.max_workspace_size = 1 << 32  # 4 GiB of builder scratch space; lower this if the TX2 NX runs out of memory
            builder.max_batch_size = 1  # ignored for EXPLICIT_BATCH networks, kept for older TRT code paths

            # Parse the model file
            if not os.path.exists(onnx_file_path):
                print('ONNX file {} not found, please run "main()" first to generate it.'.format(onnx_file_path))
                exit(0)
            print('Loading ONNX file from path {}'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print("Beginning ONNX file parsing")
                if not parser.parse(model.read()):
                    print("ERROR: Failed to parse the ONNX file.")
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
                    return None

            ## The actual model.onnx is generated with batch size 4, reshape input to batch_size 1
            # network.get_input(0).shape = [1, 3, 768, 768]

            print("Completed parsing of ONNX file")
            print("Building an engine from file {}; this may take a while".format(onnx_file_path))
            plan = builder.build_serialized_network(network, config)
            if plan is None:
                print("ERROR: Failed to build the serialized engine.")
                return None
            engine = runtime.deserialize_cuda_engine(plan)
            print("Completed creating Engine")
            with open(engine_file_path, 'wb') as f:
                f.write(plan)
            return engine
    if os.path.exists(engine_file_path):
        # If a serialized engine already exists, use it instead of building a new one.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()


def ONNX2TRT(onnx_file, engine_file):
    """Create a TensorRT engine for ONNX-based trained_model and run inference"""

    # Try to load a previously generated network graph in ONNX format:
    onnx_file_path = onnx_file
    engine_file_path = engine_file

    get_engine(onnx_file_path, engine_file_path)


if __name__ == '__main__':
    # for use on the Jetson TX2 NX board
    onnx_file_path = '.../model.onnx'
    engine_file_path = '.../model.trt'
    ONNX2TRT(onnx_file_path, engine_file_path)
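Once the .trt file exists, the engine can be used for inference on the board. The snippet below is only a minimal sketch of the TensorRT 8 Python workflow with pycuda (pycuda itself, and the single-input/single-output binding layout, are assumptions; adapt the binding handling to your model):

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates the CUDA context)
import pycuda.driver as cuda
import tensorrt as trt

engine = get_engine('.../model.onnx', '.../model.trt')
context = engine.create_execution_context()

# Allocate pinned host buffers and device buffers for every binding.
bindings, host_bufs, dev_bufs = [], [], []
for binding in engine:
    shape = engine.get_binding_shape(binding)
    dtype = trt.nptype(engine.get_binding_dtype(binding))
    host_buf = cuda.pagelocked_empty(trt.volume(shape), dtype)
    dev_buf = cuda.mem_alloc(host_buf.nbytes)
    host_bufs.append(host_buf)
    dev_bufs.append(dev_buf)
    bindings.append(int(dev_buf))

# Copy a (dummy) input in, run synchronously, copy the output back out.
host_bufs[0][:] = np.random.rand(host_bufs[0].size).astype(host_bufs[0].dtype)
cuda.memcpy_htod(dev_bufs[0], host_bufs[0])
context.execute_v2(bindings)
cuda.memcpy_dtoh(host_bufs[1], dev_bufs[1])
print(host_bufs[1][:10])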
