本文是代码分享。TensorRT的基本知识点请见我的博文《TensorRT介绍》,其中介绍了TensorRT的工作原理和Python API。
环境配置
购买的腾讯云服务器,配置为GPU计算型GN7/8核/32GB/5Mbps。我的docker容器基本配置如下。之所以在配置中注明opencv-python的版本号,是因为在推理过程中曾遇到bug,降低opencv-python的版本后bug消失;具体是什么bug,已经记不清了。
torch : 1.10.0
cuda :11.3
cudnn :8200
tensorrt :8.2.4.2
opencv-python:4.7.0.72
python 3.8.16
可使用如下代码查看torch,cuda,cudnn和tensorrt的版本信息
import torch
import tensorrt as trt
# Gather the version of each component first, then print one per line.
versions = (
    torch.__version__,               # PyTorch
    torch.version.cuda,              # CUDA
    torch.backends.cudnn.version(),  # cuDNN
    trt.__version__,                 # TensorRT
)
print(*versions, sep="\n")
python API 推理
如下代码基于如上环境,实测可行,支持动态输入。我基于mmdeploy编译得到libmmdeploy_tensorrt_ops.so动态库,推理代码也参考了mmdeploy项目。
import torch
import cv2
import tensorrt as trt
import numpy as np
def trt_version():
    """Return the version string exposed by the imported ``tensorrt`` package."""
    version = trt.__version__
    return version
def torch_device_from_trt(device):
    """Map a TensorRT tensor location to the matching ``torch.device``.

    Args:
        device: A ``trt.TensorLocation`` value.

    Returns:
        ``torch.device("cuda")`` for ``TensorLocation.DEVICE`` and
        ``torch.device("cpu")`` for ``TensorLocation.HOST``.

    Raises:
        TypeError: If ``device`` is neither of the two locations above.
    """
    if device == trt.TensorLocation.DEVICE:
        return torch.device("cuda")
    if device == trt.TensorLocation.HOST:
        return torch.device("cpu")
    # Raise (not return) the error: returning the exception object would hand
    # callers a TypeError instance where they expect a torch.device.
    raise TypeError("%s is not supported by torch" % device)
def torch_dtype_from_trt(dtype):
    """Map a TensorRT data type to the matching ``torch`` dtype.

    Args:
        dtype: A TensorRT data type (``trt.int8``, ``trt.bool``, ``trt.int32``,
            ``trt.float16`` or ``trt.float32``).

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        TypeError: If ``dtype`` has no mapping here.
    """
    # Compare the major version numerically. A plain string comparison
    # ('10.0.1' >= '7.0') is lexicographic and evaluates to False, which would
    # wrongly disable the bool mapping on TensorRT 10 and later.
    trt_major = int(trt_version().split('.')[0])

    if dtype == trt.int8:
        return torch.int8
    # The version check comes first so ``trt.bool`` is only touched on
    # releases that are new enough to pass it.
    elif trt_major >= 7 and dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError("%s is not supported by torch" % dtype)
class TRTModule(torch.nn.Module):
    """Run a deserialized TensorRT engine as a ``torch.nn.Module``.

    Args:
        engine: A deserialized TensorRT engine, or ``None`` to build an empty
            module (calling it then raises ``RuntimeError``).
        input_names: Engine binding names of the inputs, in the order the
            tensors are passed to ``forward``.
        output_names: Engine binding names of the outputs, in the order they
            should be returned.
    """

    def __init__(self, engine=None, input_names=None, output_names=None):
        super(TRTModule, self).__init__()
        self.engine = engine
        # Always define the attribute so a module created without an engine
        # fails with a clear message instead of an AttributeError.
        self.context = None
        if self.engine is not None:
            self.context = self.engine.create_execution_context()
        self.input_names = input_names
        self.output_names = output_names

    def forward(self, *inputs):
        """Enqueue inference on the current CUDA stream and return the outputs.

        Args:
            *inputs: One CUDA tensor per entry of ``input_names``, in the same
                order. ``torch.long`` tensors are converted to ``int32``.

        Returns:
            A single tensor when there is exactly one output name, otherwise a
            tuple of tensors ordered like ``output_names``.

        Raises:
            RuntimeError: If the module was built without an engine.
            ValueError: If fewer tensors than input names are supplied.
            AssertionError: If an input's rank or shape is outside the range of
                optimisation profile 0, or the input is not a CUDA tensor.
        """
        if self.context is None:
            raise RuntimeError(
                'TRTModule has no engine; pass a deserialized engine to the constructor.')
        if len(inputs) < len(self.input_names):
            raise ValueError(
                f'Expected {len(self.input_names)} input tensor(s) but got {len(inputs)}.')

        bindings = [None] * (len(self.input_names) + len(self.output_names))
        profile_id = 0
        # Hold references to the (possibly converted) input tensors so the
        # memory behind the raw pointers stays valid until work is enqueued.
        live_inputs = []

        # Pair every input name with its own tensor; binding inputs[0] to all
        # names would feed the same data to every input of the engine.
        for input_name, input_tensor in zip(self.input_names, inputs):
            profile = self.engine.get_profile_shape(profile_id, input_name)
            assert input_tensor.dim() == len(profile[0]), 'Input dim is different from engine profile.'
            for s_min, s_input, s_max in zip(profile[0], input_tensor.shape, profile[2]):
                assert s_min <= s_input <= s_max, \
                    'Input shape should be between ' \
                    + f'{profile[0]} and {profile[2]}' \
                    + f' but get {tuple(input_tensor.shape)}.'
            idx = self.engine.get_binding_index(input_name)

            # All input tensors must be gpu variables
            assert 'cuda' in input_tensor.device.type
            input_tensor = input_tensor.contiguous()
            if input_tensor.dtype == torch.long:
                input_tensor = input_tensor.int()
            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            live_inputs.append(input_tensor)
            bindings[idx] = input_tensor.data_ptr()

        outputs = []
        for output_name in self.output_names:
            idx = self.engine.get_binding_index(output_name)
            dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
            # Ask the execution context, not the engine: with dynamic shapes
            # the engine reports -1 for unresolved dimensions, while the
            # context resolves them from the input shapes set above.
            shape = tuple(self.context.get_binding_shape(idx))
            device = torch_device_from_trt(self.engine.get_location(idx))
            output = torch.empty(size=shape, dtype=dtype, device=device)
            outputs.append(output)
            bindings[idx] = output.data_ptr()

        # Optimisation profiles and set_binding_shape only exist for
        # explicit-batch engines, which are executed through the *_v2 API
        # (no separate batch-size argument).
        self.context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)

        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)
def input_propress(img_path, size=(1333, 800),
                   mean=(123.675, 116.28, 103.53),
                   std=(58.395, 57.12, 57.375),
                   to_rgb=True, device='cuda:0'):
    """Load an image and turn it into a network input tensor.

    Steps: read the image, resize it keeping the aspect ratio, convert
    BGR -> RGB, normalize, then move it to ``device`` as a ``(1, 3, H, W)``
    float32 tensor.

    NOTE(review): the original function only listed these steps as comments.
    The default ``size``/``mean``/``std`` are common ImageNet-style values and
    are an assumption -- they must match the test pipeline of the config the
    engine was exported from; confirm before use.

    Args:
        img_path: Path of the image file to read.
        size: The two edge limits; the larger bounds the long image edge and
            the smaller bounds the short image edge.
        mean: Per-channel mean, in the channel order used after ``to_rgb``.
        std: Per-channel standard deviation, same order as ``mean``.
        to_rgb: Convert the BGR image returned by OpenCV to RGB.
        device: Torch device the tensor is moved to.

    Returns:
        Tuple ``(img, scale_factor)``: the input tensor and a float32 array
        ``[w_scale, h_scale, w_scale, h_scale]`` describing the resize.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    # Read the image (OpenCV returns None instead of raising on failure).
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f'Cannot read image: {img_path}')

    # Resize, keeping the aspect ratio within the given edge limits.
    height, width = img.shape[:2]
    long_edge, short_edge = max(size), min(size)
    ratio = min(long_edge / max(height, width), short_edge / min(height, width))
    new_w = int(width * ratio + 0.5)
    new_h = int(height * ratio + 0.5)
    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    w_scale = new_w / width
    h_scale = new_h / height
    scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32)

    # toRGB
    if to_rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Normalize
    img = img.astype(np.float32)
    img = (img - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)

    # HWC -> NCHW, then move to the target device.
    img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))).unsqueeze(0)
    img = img.to(device)
    return img, scale_factor
import ctypes

# Placeholders: replace with the real config, engine and image paths.
model_cfg = '**.py'
engine_model = 'end2end.engine'
img_path = '**.jpg'

logger = trt.Logger(trt.Logger.INFO)

# Load the mmdeploy custom-op library before deserializing, then register
# the TensorRT plugins.
ctypes.CDLL('/root/workspace/mmdeploy/mmdeploy/lib/libmmdeploy_tensorrt_ops.so')
trt.init_libnvinfer_plugins(logger, "")

with open(engine_model, "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

# deserialize_cuda_engine reports failure by returning None; stop here with a
# clear message instead of failing later on an attribute access.
if engine is None:
    raise RuntimeError(f'Failed to deserialize TensorRT engine: {engine_model}')

# Walk over the engine bindings; enable the print below to inspect them.
for idx in range(engine.num_bindings):
    is_input = engine.binding_is_input(idx)
    name = engine.get_binding_name(idx)
    op_type = engine.get_binding_dtype(idx)
    shape = engine.get_binding_shape(idx)
    # print('input id:',idx,' is input: ', is_input,' binding name:', name, ' shape:', shape, 'type: ', op_type)

trt_model = TRTModule(engine, ["input"], ["dets", "labels"])
img_input, scale_factor = input_propress(img_path)
result_trt = trt_model(img_input)