# 安装 TensorRT:
# 1、下载与电脑中 cuda 和 cudnn 版本对应的 TensorRT(比如我的是 TensorRT-8.2.1.8.Windows10.x86_64.cuda-11.4.cudnn8.2)
# 2、打开目录,里面有 python 文件夹,找到对应 python 版本的 whl 文件(我的是 tensorrt-8.2.1.8-cp38-none-win_amd64.whl,因为我安装的 python 是 3.8 版本)
# 3、终端安装:pip install tensorrt-8.2.1.8-cp38-none-win_amd64.whl
# 4、结束
import os

import tensorrt as trt
def get_DynEngine(onnx_file_path, engine_file_path, patchsize, max_workspace_size, max_batch_size):
    '''
    Build a TensorRT engine with a dynamic batch dimension from an ONNX model
    and save the serialized engine to disk.

    Args:
        onnx_file_path: path to the source ONNX model.
        engine_file_path: path where the serialized engine is written.
        patchsize: fixed spatial input size, e.g. [D, H, W] — the profile
            shapes are built as (batch, 1, *patchsize).
        max_workspace_size: builder scratch-memory limit in bytes.
        max_batch_size: largest batch the optimization profile allows
            (inference batch must be <= this).

    Returns:
        The serialized engine (trt.IHostMemory) on success, or None if
        parsing or building failed.
    '''
    TRT_LOGGER = trt.Logger()
    trt.init_libnvinfer_plugins(TRT_LOGGER, "")
    # Explicit-batch network definition is required by the ONNX parser.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(explicit_batch)
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)
    print("common.EXPLICIT_BATCH:", explicit_batch)
    # Maximum builder scratch memory; lower this if the GPU runs out of memory.
    config.max_workspace_size = max_workspace_size
    config.set_flag(trt.BuilderFlag.FP16)
    print("max_workspace_size:", config.max_workspace_size)
    builder.max_batch_size = max_batch_size  # inference batch must be <= max_batch_size
    if not os.path.exists(onnx_file_path):
        print(f'onnx file {onnx_file_path} not found,please run torch_2_onnx.py first to generate it')
        # Missing input is a failure: exit non-zero (original exited 0).
        exit(1)
    print(f'Loading ONNX file from path {onnx_file_path}...')
    with open(onnx_file_path, 'rb') as model:
        print('Beginning ONNX file parsing')
        if not parser.parse(model.read()):
            print('ERROR:Failed to parse the ONNX file')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    print("input", inputs)
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print("out:", outputs)
    print("Network Description")
    for inp in inputs:  # 'inp', not 'input': avoid shadowing the builtin
        print("Input '{}' with shape {} and dtype {} . ".format(inp.name, inp.shape, inp.dtype))
    for output in outputs:
        print("Output '{}' with shape {} and dtype {} . ".format(output.name, output.shape, output.dtype))
    # Dynamic input: the allowed shape range is declared through an
    # optimization profile (min, opt, max) on the builder config.
    profile = builder.create_optimization_profile()
    print("network.get_input(0).name:", network.get_input(0).name)
    profile.set_shape(network.get_input(0).name, (1, 1, *patchsize), (1, 1, *patchsize),
                      (max_batch_size, 1, *patchsize))  # min / opt / max; inference shapes must fall in this range
    config.add_optimization_profile(profile)
    print('Completed parsing the ONNX file')
    print(f'Building an engine from file {onnx_file_path}; this may take a while...')
    engine = builder.build_serialized_network(network, config)
    if engine is None:
        # Build failures return None; writing it would crash with a TypeError.
        print('ERROR:Failed to build the engine')
        return None
    print('Completed creating Engine')
    with open(engine_file_path, 'wb') as f:
        f.write(engine)
    return engine
if __name__ == "__main__":
    # Build a dynamic-batch engine for 1.onnx: 5 GiB workspace, max batch 2.
    onnx_path = "1.onnx"
    engine_path = "2.engine"
    patch_size = [96, 160, 160]
    workspace_bytes = 5 * (1 << 30)
    get_DynEngine(onnx_path, engine_path, patch_size, workspace_bytes, 2)