使用TensorRt搭建自己的模型
前言
在推理过程中,基于 TensorRT 的应用程序的执行速度可比 CPU 平台的速度快 40 倍。借助 TensorRT,您可以优化在所有主要框架中训练的神经网络模型,精确校正低精度,并最终将模型部署到超大规模数据中心、嵌入式或汽车产品平台中。
TensorRT 以 NVIDIA 的并行编程模型 CUDA 为基础构建而成,可帮助您利用 CUDA-X 中的库、开发工具和技术,针对人工智能、自主机器、高性能计算和图形优化所有深度学习框架中的推理。
TensorRT 针对多种深度学习推理应用的生产部署提供 INT8 和 FP16 优化,例如视频流式传输、语音识别、推荐和自然语言处理。推理精度降低后可显著减少应用延迟,这恰巧满足了许多实时服务、自动和嵌入式应用的要求。
一、问题
TensorRTx旨在通过tensorrt网络定义API实现流行的深度学习网络。众所周知,tensorrt内置了解析器,包括caffeparser,uffparser,onnxparser等。但是,当我们使用这些解析器时,经常会遇到一些“不受支持的操作或层”问题,尤其是某些正在使用的最新模型新类型的图层。
二、搭建过程
1.首先非常感谢这位大佬: https://github.com/wang-xinyu/tensorrtx , 牛逼plus ,实现了很多模型 (googlenet、resnet、shuffenetv2、yolov3、yolov4、yolov5等等) ,有兴趣的童鞋可以去瞅瞅, 下面代码实现也都参考这大佬程序改造而成的。但大佬的模型基本都是基于pytorch的,我的模型是基于tf的,但是应该差别不大。
2.开撸前先准备好材料,https://github.com/wdhao/tensorRT_Wheels
这是另一个大佬写的trt的一些层,我们的实现需要参考大佬的代码。
还有一个就是trt的官方文档,有不懂的层就到文档看看怎么使用
https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#raggedsoftmax-layer
接下来开撸:
首先要把tf的模型参数提取到wts中,在大佬的描述中,wts文件组成是这样的
因为大佬的gen_wts.py是pytorch的,我们根据这个写出tf的gen_tf_wts.py,因为tf和trt的数据格式不一样,所以要进行转置,tf的数据格式为NHWC,trt的数据格式为NCHW,代码如下
import tensorflow as tf
from tensorflow.python.framework import tensor_util
import struct
import numpy as np
output_graph_path = r"xxx.pb"
with tf.Session() as sess:
tf.global_variables_initializer().run()
output_graph_def = tf.GraphDef()
graph = tf.get_default_graph()
with open(output_graph_path,"rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def,name="")
graph_nodes = [n for n in output_graph_def.node]
wts = [n for n in graph_nodes if n.op == 'Const']
f = open(r"xxx.wts", "w")
f.write("{}\n".format(len(wts)))
for i,n in enumerate(wts):
#print(n)
print("Name of the node - %s" % n.name