1、将权重（weights）文件转换为 ONNX
我的模型使用 PyTorch 搭建，但为了方便训练，使用了 PyTorch Lightning 进行封装，因此下面的代码可以看作针对 PyTorch Lightning 的改编版本。在转换 ONNX 文件时遇到了一些问题，在此记录，方便日后查阅：
代码参考:https://blog.csdn.net/qq_37541097/article/details/114847600
import torch
import torch.onnx
import onnx
import onnxruntime
import numpy as np
def to_numpy(tensor):
    """Return the data of a torch tensor as a NumPy array on the CPU.

    ``Tensor.numpy()`` refuses tensors that require grad, so such tensors are
    detached from the autograd graph first.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
def _load_weights(network, weights_path, key_prefix=''):
    """Load a plain PyTorch state dict or a PyTorch Lightning checkpoint into ``network``.

    Raises:
        RuntimeError: if any parameter/buffer of ``network`` is absent from the
            checkpoint. Exporting such a model would silently bake randomly
            initialised weights into the ONNX file.
    """
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts.
    # NOTE: on torch>=2.6 torch.load defaults to weights_only=True; a Lightning
    # checkpoint holding custom hyper-parameter objects may then need
    # weights_only=False (only do that for files you trust).
    checkpoint = torch.load(weights_path, map_location='cpu')
    # Lightning .ckpt files keep the weights under 'state_dict', next to
    # optimizer state and training metadata.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    # Lightning prefixes every key with the attribute name the network had
    # inside the LightningModule (e.g. "model."); strip it when requested.
    if key_prefix:
        checkpoint = {
            (key[len(key_prefix):] if key.startswith(key_prefix) else key): value
            for key, value in checkpoint.items()
        }
    result = network.load_state_dict(checkpoint, strict=False)
    if result.missing_keys:
        raise RuntimeError(
            '{} parameters/buffers of the network were not found in {}, e.g. {}. '
            'If the checkpoint keys carry a LightningModule attribute prefix, '
            'pass key_prefix (e.g. "model.").'.format(
                len(result.missing_keys), weights_path, result.missing_keys[:5]))
    if result.unexpected_keys:
        # Extra entries (e.g. loss-module buffers) do not affect the export.
        print('Ignoring {} unexpected checkpoint keys, e.g. {}'.format(
            len(result.unexpected_keys), result.unexpected_keys[:5]))


def _check_with_onnxruntime(onnx_file_name, x, torch_out):
    """Validate the ONNX file and compare its ONNX Runtime output with PyTorch's."""
    onnx_model = onnx.load(onnx_file_name)
    onnx.checker.check_model(onnx_model, full_check=True)
    # GPU builds of onnxruntime (>= 1.9) refuse to create a session without an
    # explicit providers list; the CPU provider matches the CPU PyTorch reference.
    ort_session = onnxruntime.InferenceSession(
        onnx_file_name, providers=['CPUExecutionProvider'])
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)
    # A multi-output model returns a tuple/list; compare its first element,
    # which corresponds to ort_outs[0].
    reference = torch_out[0] if isinstance(torch_out, (tuple, list)) else torch_out
    # assert_allclose raises AssertionError if the outputs differ beyond tolerance.
    np.testing.assert_allclose(to_numpy(reference), ort_outs[0], rtol=1e-02, atol=1e-05)


def main(network=None, weights_path='./ckpt', onnx_file_name='onnx文件名称.onnx',
         img_h=None, img_w=None, batch_size=1, img_channel=3, key_prefix=''):
    """Export a trained PyTorch model to ONNX and verify it with ONNX Runtime.

    Args:
        network: The model skeleton (an ``nn.Module``) that receives the
            weights from ``weights_path``. Required.
        weights_path: Checkpoint path. PyTorch Lightning writes ``.ckpt``
            files, plain PyTorch usually ``.pth``.
        onnx_file_name: Output path of the exported ONNX model.
        img_h: Height of the dummy input image. Required.
        img_w: Width of the dummy input image. Required.
        batch_size: Batch dimension of the dummy input.
        img_channel: Channel count of the dummy input (3 for RGB images).
        key_prefix: Prefix to strip from checkpoint keys, e.g. the attribute
            name the network had inside a LightningModule ("model.").

    Raises:
        ValueError: if ``network``, ``img_h`` or ``img_w`` is not provided.
        RuntimeError: if the checkpoint does not cover all network weights.
        AssertionError: if ONNX Runtime and PyTorch outputs disagree.
    """
    if network is None or img_h is None or img_w is None:
        raise ValueError(
            'main() needs `network` (your model skeleton) and the input size '
            '`img_h`/`img_w`.')
    _load_weights(network, weights_path, key_prefix)
    model = network.eval()
    # Dummy input laid out as [batch_size, channel, height, width].
    x = torch.rand(batch_size, img_channel, img_h, img_w)
    with torch.no_grad():
        torch_out = model(x)
    # Exported with torch.onnx.export: LightningModule.to_onnx was tried and
    # produced large numeric mismatches in the check below.
    torch.onnx.export(model,            # model being run
                      x,                # model input (or a tuple for multiple inputs)
                      onnx_file_name,   # where to save the model (file path or file-like object)
                      input_names=['input'],
                      output_names=['output'],
                      verbose=False)
    _check_with_onnxruntime(onnx_file_name, x, torch_out)
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")
推荐使用 PyTorch 原生的转换方式（torch.onnx.export）。PyTorch Lightning 的转换似乎不太稳定；而且国内目前使用 Lightning 的人应该比较少——记录其用法的博客都非常浅显，关于 Lightning 的一些深入用法我也还没有看到……
2、将 ONNX 文件转换为 TensorRT 引擎（.trt）文件
参考代码：【python】tensorrt8版本下的onnx转tensorrt engine（CSDN 博客）
import tensorrt as trt # tensorrt installed in JetsonTX2 NX
import os
# Bit mask passed to Builder.create_network() to request an explicit-batch
# network definition.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
# One logger instance shared by the builder, ONNX parser and runtime below.
TRT_LOGGER = trt.Logger()
# The Python conversion below is a fallback for when converting the .onnx
# file with TensorRT's trtexec executable does not succeed.
def get_engine(onnx_file_path, engine_file_path=''):
"""Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""
def build_engine():
"""Takes an ONNX file and creates a TensorRT engine to run inference with"""
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
EXPLICIT_BATCH
) as network, builder.create_builder_config() as config, trt.OnnxParser(
network, TRT_LOGGER
) as parser, trt.Runtime(
TRT_LOGGER
) as runtime:
config.max_workspace_size = 1 << 32
builder.max_batch_size = 1
# Parser model file
if not os.path.exists(onnx_file_path):
print('ONNX file {} not found, please run "main()" first to generate it.'.format(onnx_file_path))
exit(0)
print('Loading ONNX file from path {}'.format(onnx_file_path))
with open(onnx_file_path, 'rb') as model:
print("Beginning ONNX file parsing")
if not parser.parse(model.read()):
print("ERROR: Failed to parse the ONNX file.")
for error in range(parser.num_errors):
print(parser.get_error(error))
return None
## The actual model.onnx is generated with batch size 4, reshape input to batch_size 1
# network.get_input(0).shape = [1, 3, 768, 768]
print("Completed parsing of ONNX file")
print("Building an engine from file {}; this may take a while".format(onnx_file_path))
plan = builder.build_serialized_network(network, config)
engine = runtime.deserialize_cuda_engine(plan)
print("Completed creating Engine")
with open(engine_file_path, 'wb') as f:
f.write(plan)
return engine
if os.path.exists(engine_file_path):
# if a serialized engine exists, use it instead of buildering an engine.
print("Reading engine from file {}".format(engine_file_path))
with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
return runtime.deserialize_cuda_engine(f.read())
else:
return build_engine()
def ONNX2TRT(onnx_file, engine_file):
    """Convert an ONNX model into a serialized TensorRT engine file.

    Loads ``engine_file`` if it already exists; otherwise builds the engine
    from ``onnx_file`` and writes it to ``engine_file``. No inference is run.

    Returns:
        The engine produced by ``get_engine`` (None if parsing/building failed).
    """
    return get_engine(onnx_file, engine_file)
if __name__ == '__main__':
    # Intended to run on the Jetson TX2 NX board; replace the placeholder paths.
    onnx_file_path, engine_file_path = '.../model.onnx', '.../model.trt'
    ONNX2TRT(onnx_file_path, engine_file_path)