def onnx_2_trt(onnx_model_name, trt_model_name):
    """Convert an ONNX model file into a serialized TensorRT engine.

    The ONNX file is parsed into a TensorRT network, an FP16 engine is built
    from that network, and the serialized engine is written to
    ``trt_model_name``. Progress and failures are reported on stdout.

    Args:
        onnx_model_name: Path of the ONNX model file to read.
        trt_model_name: Path the serialized engine is written to.

    Returns:
        The built engine, or ``None`` when ONNX parsing or the engine build
        fails.
    """
    with trt.Builder(G_LOGGER) as builder, \
            builder.create_network() as network, \
            trt.OnnxParser(network, G_LOGGER) as parser:
        builder.max_batch_size = 1024
        builder.max_workspace_size = 2 << 30  # 2 GiB

        print('Loading ONNX file from path {}...'.format(onnx_model_name))
        with open(onnx_model_name, 'rb') as model:
            print('Beginning ONNX file parsing')
            parsed = parser.parse(model.read())

        if not parsed:
            # Report every parser error, not only the first one.
            print('Number of errors: {}'.format(parser.num_errors))
            for index in range(parser.num_errors):
                print(parser.get_error(index).desc())
            return None

        print('Completed parsing of ONNX file')
        print('Building an engine from file {}; this may take a while...'
              .format(onnx_model_name))
        ####
        #builder.int8_mode = True
        #builder.int8_calibrator = calib
        builder.fp16_mode = True
        ####
        print("layers:", network.num_layers)

        # Some models need their output marked explicitly; others already had
        # their outputs specified when they were exported to ONNX. Only mark
        # the last layer's first output when the network has no output yet.
        if network.num_outputs == 0 and network.num_layers > 0:
            last_layer = network.get_layer(network.num_layers - 1)
            network.mark_output(last_layer.get_output(0))

        engine = builder.build_cuda_engine(network)
        print(engine)
        if engine is None:
            # Calling serialize() on a failed build would raise AttributeError.
            print("Failed to create Engine")
            return None

        print("Completed creating Engine")
        with open(trt_model_name, "wb") as f:
            f.write(engine.serialize())
        return engine