目录
最新补充
C++ 的例子可以参考最新的笔记:TRT_Serialize_and_Deserialize
前言
所有代码均基于 Tensor RT 8.2.5 版本,使用的是 Python 环境。
1--序列化和反序列化的概念
将某个对象的信息转化成可以存储或者传输的信息,这个过程称为序列化;
反序列化是序列化的相反过程,将信息还原为序列化前的状态;
在 Pytorch 中,当序列化为 torch.save( ) 时,则反序列化可以是 torch.load( );
在 Tensor RT 中,为了能够在 inference 的时候不需要重复编译 engine,倾向于将模型 序列化 成一个能够永久保存的 engine;当需要 inference 的时候,只需要通过简单的 反序列化 就能够快速加载 序列化保存好的模型 engine,节省部署开发的时间。
2--代码实现
2-1--序列化并保存模型
# 创建序列化engine
engine = builder.build_serialized_network(network, config)
# 保存序列化保存模型,便于后续直接调用
if True:
saved_trt_path = "./serialize_fcn-resnet101.trt" # 序列化模型保存的地址
with open(saved_trt_path, "wb") as f:
f.write(engine) # 保存序列化模型
2-2--反序列化加载模型
# 反序列化加载模型
f = open(saved_trt_path, "rb") # 打开保存的序列化模型
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(f.read()) # 反序列化加载模型
# 创建context用来执行推断
context = engine.create_execution_context()
2-3--完整代码
import pycuda
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
if __name__ == "__main__":
# 创建日志记录器
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
# 显式batch_size,batch_size有显式和隐式之分
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
# 创建builder,用于创建network
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(EXPLICIT_BATCH) # 创建network(初始为空)
# 创建config
config = builder.create_builder_config()
profile = builder.create_optimization_profile() # 创建profile
profile.set_shape("input", (1,3,256,256), (1,3,1026,1282), (1,3,1280,1536)) # 设置动态输入,"input"对应onnx模型的输入"name"
#(1,3,256,256), (1,3,1026,1282), (1,3,1280,1536) 分别对应:最小尺寸、最佳尺寸、最大尺寸
config.add_optimization_profile(profile)
config.max_workspace_size = 1<<30 # 允许TensorRT使用1GB的GPU内存,<<表示左移,左移30位即扩大2^30倍,使用2^30 bytes即 1 GB
# 创建parser用于解析模型
parser = trt.OnnxParser(network, TRT_LOGGER)
# 读取并解析模型
onnx_model_file = "./fcn-resnet101.onnx" # Onnx模型的地址
model = open(onnx_model_file, 'rb')
if not parser.parse(model.read()): # 解析模型
for error in range(parser.num_errors):
print(parser.get_error(error)) # 打印错误(如果解析失败,根据打印的错误进行Debug)
# 创建序列化engine
engine = builder.build_serialized_network(network, config)
# 保存序列化保存模型,便于后续直接调用
if True:
saved_trt_path = "./serialize_fcn-resnet101.trt" # 序列化模型保存的地址
with open(saved_trt_path, "wb") as f:
f.write(engine) # 保存序列化模型
# 反序列化加载模型
f = open(saved_trt_path, "rb") # 打开保存的序列化模型
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(f.read()) # 反序列化加载模型
# 创建context用来执行推断
context = engine.create_execution_context()
'''
...后续步骤
'''