Background:
- Testing direct PyTorch YOLOv3 → TensorRT conversion on a Jetson Nano 【8】
- With this code, the model has to be re-converted on every run, and each conversion takes about 5 minutes, so I wondered whether the converted model could be saved to disk
Approaches:
- 1. Serialize the Python object directly (obviously shaky); I tried it, and sure enough it didn't work
- 2. Follow the official TensorRT documentation
- 3. Follow the official torch2trt repository
Following the official TensorRT documentation (it turns out this doesn't work for this codebase, though serialization itself does work)
- Approach 1 is too naive, so skip it and go straight to approach 2, which looked fairly promising at first
- Official documentation link
- Go to section 3.4, serializing a model in Python
- The relevant code is easy to find:
```python
# serialize
serialized_engine = engine.serialize()

# serialize and save to a file
with open("sample.engine", "wb") as f:
    f.write(engine.serialize())

# deserialize
with trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(serialized_engine)

# deserialize from a file
with open("sample.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
```
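The serialize and deserialize snippets above can be combined into a build-once/load-later cache, which is exactly what we want here (skip the 5-minute conversion whenever a saved engine already exists). The helper below is my own sketch: the caching logic is library-agnostic, and the commented lines show roughly how it would plug into the TensorRT calls above (`build_engine` is a hypothetical wrapper around the builder).

```python
import os

def load_or_build(path, build_fn, serialize_fn, deserialize_fn):
    """Build-once/load-later cache: if `path` exists, deserialize from it;
    otherwise build the object, serialize it to `path`, and return it."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return deserialize_fn(f.read())
    obj = build_fn()
    with open(path, "wb") as f:
        f.write(serialize_fn(obj))
    return obj

# With TensorRT this would be used roughly as:
# engine = load_or_build("sample.engine",
#                        build_engine,                                  # hypothetical: builds the engine (~5 min)
#                        lambda e: e.serialize(),                       # ICudaEngine -> bytes
#                        lambda b: runtime.deserialize_cuda_engine(b))  # bytes -> ICudaEngine
```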
- So what exactly is this engine? That story starts with how TensorRT builds a network model. I don't fully understand the build process yet (worth writing up properly some day), so for now let's just look at where the engine comes from
- For that, the torch2trt source code is a convenient reference:
```python
def torch2trt(module,
              inputs,
              input_names=None,
              output_names=None,
              log_level=trt.Logger.ERROR,
              max_batch_size=1,
              fp16_mode=False,
              max_workspace_size=0,
              strict_type_constraints=False,
              keep_network=True,
              int8_mode=False,
              int8_calib_dataset=None,
              int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM):

    inputs_in = inputs

    # copy inputs to avoid modifications to source data
    inputs = [tensor.clone()[0:1] for tensor in inputs]  # only run single entry

    logger = trt.Logger(log_level)
    builder = trt.Builder(logger)
    network = builder.create_network()

    with ConversionContext(network) as ctx:
        if isinstance(inputs, list):
            inputs = tuple(inputs)
        if not isinstance(inputs, tuple):
            inputs = (inputs, )
        ctx.add_inputs(inputs, input_names)

        outputs = module(*inputs)

        if not isinstance(outputs, tuple) and not isinstance(outputs, list):
            outputs = (outputs, )
        ctx.mark_outputs(outputs, output_names)

    builder.max_workspace_size = max_workspace_size
    builder.fp16_mode = fp16_mode
    builder.max_batch_size = max_batch_size
    builder.strict_type_constraints = strict_type_constraints

    if int8_mode:
        # default to use input tensors for calibration
        if int8_calib_dataset is None:
            int8_calib_dataset = TensorBatchDataset(inputs_in)

        builder.int8_mode = True

        # @TODO(jwelsh): Should we set batch_size=max_batch_size? Need to investigate memory consumption
        builder.int8_calibrator = DatasetCalibrator(inputs, int8_calib_dataset, batch_size=1, algorithm=int8_calib_algorithm)

    engine = builder.build_cuda_engine(network)

    module_trt = TRTModule(engine, ctx.input_names, ctx.output_names)

    if keep_network:
        module_trt.network = network

    return module_trt
```
- The key line is `engine = builder.build_cuda_engine(network)`: the builder constructs the network, and the engine is built by that same builder
- So in principle we can serialize the engine with the official code above. However, torch2trt returns a `TRTModule`, so we have to reach inside it for the engine, and on the next run pass the deserialized engine back in when constructing the `TRTModule`. In my tests this approach gets part-way there (serialization succeeds and the model builds), but inference afterwards raises an error; the exact cause is unknown
- Here is the serialization code:

```python
serialized_engine = model_trt.engine.serialize()
with open("speed.engine", "wb") as f:
    f.write(serialized_engine)
```
- Forcing the deserialized engine back in raises this error:

```
Traceback (most recent call last):
  File "/home/nano/Desktop/YOLOv3-Torch2TRT/mydetect.py", line 190, in <module>
    detections = non_max_suppression(detections, opt.conf_thres, opt.nms_thres, method=2)
  File "/home/nano/Desktop/YOLOv3-Torch2TRT/utils/utils.py", line 254, in non_max_suppression
    image_pred = image_pred[(-score).argsort()]
IndexError: too many indices for tensor of dimension 2
```
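For background on the message itself: this `IndexError` means the indexing expression uses more indices than the tensor has dimensions, which usually points to an output coming back with an unexpected shape. A minimal NumPy analogue (NumPy's wording differs slightly from PyTorch's):

```python
import numpy as np

# non_max_suppression expects a 2-D (num_boxes, num_attrs) prediction
# array per image; if a dimension is lost somewhere along the way,
# 2-D-style indexing on the result fails.
preds = np.zeros(5)      # 1-D, as if a dimension went missing
try:
    preds[0, :]          # 2-D-style indexing on a 1-D array
except IndexError as err:
    print("IndexError:", err)
```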
- Comparing the engine before serialization with the one after deserialization: apart from the memory addresses they are identical, yet inference still fails in the end. A bit mysterious, but at least it suggests the serialization itself is sound
Following the official torch2trt repository (this approach fits this codebase, since it works at the `TRTModule` level)
- My fault for not reading carefully at the time: it's right there in the README
```python
# serialize (save)
torch.save(model_trt.state_dict(), 'alexnet_trt.pth')

# load
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('alexnet_trt.pth'))
```
- I didn't time the serialization, but I did test deserialization:
  - YOLOv3-tiny loads in about 24 seconds
  - YOLOv3 loads in about 34 seconds
  - YOLOv3-SPP loads in about 37 seconds
- The loaded models work correctly
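The load times above were measured by hand; a small helper like the one below (my own sketch, not part of torch2trt) makes such measurements repeatable:

```python
import time

def timed_load(loader, *args, **kwargs):
    """Run `loader(*args, **kwargs)` and report wall-clock seconds taken."""
    start = time.perf_counter()
    result = loader(*args, **kwargs)
    elapsed = time.perf_counter() - start
    print(f"loaded in {elapsed:.1f}s")
    return result, elapsed

# With torch2trt this would wrap the deserialization step, e.g.:
# model_trt = TRTModule()
# _, secs = timed_load(model_trt.load_state_dict, torch.load('yolov3_trt.pth'))
```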