一、准备相应的库
tensorflow当然是要的。试了下
python:3.6.13 tensorflow:1.14.0
python 3.7.10 tensorflow:1.15.0, onnx=1.8.1, tf2onnx=1.8.4 是成功的。
如果要安装tensorflow-addons 可以查看https://github.com/tensorflow/addons/
安装onnx-tensorrt:
可以直接pip安装,如下所示,
pip install -U tf2onnx
也可以git安装
git clone https://github.com/onnx/onnx-tensorflow.git
cd onnx-tensorflow
python setup.py install
二、转pb文件
一般来说的模型导出方法会将网络信息与权重信息分开存储在不同文件当中,比如ckpt,meta等文件这在部署时候不是很方便。官方提供了一种Freeze Graph的方式,用于将模型相关信息统统打包到一个*.pb文件当中。
官方提供了相关工具freeze_graph,一般安装完TensorFlow后会自动添加到用户PATH相应的bin目录下,如果没有找到的话可以去TensorFlow源码tensorflow/python/tools/free_graph.py这个位置去找一下,或者直接通过命令行导入module的方式调用。
举例如下,如果有多个输出节点,用逗号隔开:
# 1.直接调用
freeze_graph --input_graph=/home/mnist-tf/graph.proto \
--input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \
--output_graph=/tmp/frozen_graph.pb \
--output_node_names=fc2/add \
--input_binary=True
##fc2/add为输出节点的名字,这个要看一下输出节点的名字是什么
# 2. 通过调用module的方式
python -m tensorflow.python.tools.freeze_graph \
--input_graph=my_checkpoint_dir/graphdef.pb \
--input_binary=true \
--output_node_names=output \
--input_checkpoint=my_checkpoint_dir \
--output_graph=frozen.pb
##output是输出节点的名字
获得
三、pb转onnx
使用以下命令即可生成
python3 -m tf2onnx.convert --input model.pb --inputs input_img:0[2,384,512,3] --outputs sigmoid_logits:0 --output model.onnx
其中input_img为网络输入的名字,sigmoid_logits为网络输出的名字,[2,384,512,3]表示输入的维度
具体详细参数可以查看https://github.com/onnx/tensorflow-onnx
获取网络输入输出的名字代码如下:
import tensorflow as tf
import os
model_dir = '/home/cidi-gpu/disk/data/liyi/tensorflow2onnx/'
model_name = 'LRLane_res2net_highway_6mm.pb'
def create_graph():
with tf.gfile.FastGFile(os.path.join(
model_dir, model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
create_graph()
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
print("input_name:"+tensor_name_list[0])
print("output_name:"+tensor_name_list[-1])
四、问题:
1、转成onnx后转trt的时候,出现Assertion failed: kernel_weights.shape.d[1] * ngroup == nchan报错。
解决:出现这个报错是因为使用tf2onnx工具的版本过低,不支持一些层,之前使用的tf2onnx=1.4,就报错了,后来使用tensorflow=1.15.0, onnx=1.8.1, tf2onnx=1.8.4就没错。而且不同的tf2onnx版本生成的onnx图是不一样的,当转trt报错的时候,可以先尝试一下
onnx是不是不一致,或许更改转onnx版本,能解决tf不成功地方问题。
在查找解决办法的时候,也有网友提出以下的解决办法:I have changed that with view() with exact dimensions, for example unsqueeze(-1) changed to view(1,1,36,60). That solved the problem.