def get_engine(onnx_file_path, engine_file_path="", *, input_shape=(1, 3, 640, 640), fp16=True):
    """Load a cached TensorRT engine, or build one from an ONNX model.

    If ``engine_file_path`` names an existing file, the serialized engine in it
    is deserialized and returned. Otherwise a new engine is built from
    ``onnx_file_path``. When ``engine_file_path`` is non-empty, the serialized
    engine is also written there for reuse.

    Args:
        onnx_file_path: Path to the ONNX model, used when no cached engine exists.
        engine_file_path: Path of the serialized-engine cache. An empty string
            disables both loading and saving the cache.
        input_shape: Shape forced onto the network's first input when building.
            It is NOT applied when a cached engine is loaded, so delete the
            cache file after changing it.
        fp16: Whether to set ``trt.BuilderFlag.FP16`` when building.

    Returns:
        The engine returned by ``runtime.deserialize_cuda_engine``, or ``None``
        if ONNX parsing or the engine build fails.

    Raises:
        SystemExit: With status 1 when a build is needed but the ONNX file
            does not exist.
    """

    def build_engine():
        """Parse the ONNX model, build a TensorRT engine and optionally cache it."""
        # Check before creating any TensorRT objects, so nothing is allocated
        # for a build that cannot happen.
        if not os.path.exists(onnx_file_path):
            print(
                "ONNX file {} not found, please run yolov5_to_onnx.py first to generate it.".format(onnx_file_path)
            )
            # Non-zero status: a missing model is an error, not a success.
            # `raise SystemExit` also avoids the `exit` builtin, which only
            # exists when the `site` module has been loaded.
            raise SystemExit(1)

        with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
            EXPLICIT_BATCH
        ) as network, builder.create_builder_config() as config, trt.OnnxParser(
            network, TRT_LOGGER
        ) as parser, trt.Runtime(
            TRT_LOGGER
        ) as runtime:
            workspace_bytes = 1 << 32  # 4 GiB
            if hasattr(config, "set_memory_pool_limit"):
                # TensorRT >= 8.4: max_workspace_size is deprecated there and
                # removed in TensorRT 10.
                config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
            else:
                config.max_workspace_size = workspace_bytes
            # builder.max_batch_size is intentionally not set. ONNX networks are
            # explicit-batch, so the batch size comes from input_shape and that
            # setting is ignored (and removed in TensorRT 10).
            if fp16:
                config.set_flag(trt.BuilderFlag.FP16)

            print("Loading ONNX file from path {}...".format(onnx_file_path))
            with open(onnx_file_path, "rb") as model:
                print("Beginning ONNX file parsing")
                if not parser.parse(model.read()):
                    print("ERROR: Failed to parse the ONNX file.")
                    for idx in range(parser.num_errors):
                        print(parser.get_error(idx))
                    return None
            # The engine's input shape must match the data fed at inference time.
            network.get_input(0).shape = list(input_shape)
            print("Completed parsing of ONNX file")

            print("Building an engine from file {}; this may take a while...".format(onnx_file_path))
            serialized_engine = builder.build_serialized_network(network, config)
            if serialized_engine is None:
                # Without this check, None would be deserialized and written to disk.
                print("ERROR: Failed to build the TensorRT engine.")
                return None
            engine = runtime.deserialize_cuda_engine(serialized_engine)
            print("Completed creating Engine")

            # An empty path means "no cache"; open("") would raise.
            if engine_file_path:
                with open(engine_file_path, "wb") as f:
                    f.write(serialized_engine)
            return engine

    if engine_file_path and os.path.exists(engine_file_path):
        # A serialized engine already exists: reuse it instead of rebuilding.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    return build_engine()
TensorRT 8 的写法主要在 max_batch_size、fp16_mode 等参数的设置方式上有所不同(例如 FP16 改为通过 config.set_flag(trt.BuilderFlag.FP16) 开启),代码如上。参考链接:https://blog.csdn.net/weixin_42492254/article/details/125319112
如果出现该错误:
pycuda._driver.LogicError: cuMemcpyHtoDAsync failed: invalid argument
那很可能是网络输入的维度有问题(输入数据类型不匹配也可能导致该错误,不过我遇到的是维度问题)。输入数据的维度一定要和 TensorRT 模型的输入保持一致:如果不一致,请检查输入预处理,或者检查 ONNX → TensorRT 转换时指定的输入维度是否正确。维度改过来后,千万记得删除旧的 engine 文件并重新生成一遍 TensorRT 引擎,否则程序仍会直接加载旧的 engine!!!