tensorflow转onnx模型

一、准备相应的库

       tensorflow当然是要的。试了下

python:3.6.13  tensorflow:1.14.0   

python 3.7.10 tensorflow:1.15.0, onnx=1.8.1, tf2onnx=1.8.4 是成功的。

如果要安装tensorflow-addons  可以查看https://github.com/tensorflow/addons/

       安装onnx-tensorrt:

可以直接pip安装,如下所示,

pip install -U tf2onnx

也可以git安装 

git clone https://github.com/onnx/onnx-tensorflow.git 
cd onnx-tensorflow 
python setup.py install

二、转pb文件

    一般来说的模型导出方法会将网络信息与权重信息分开存储在不同文件当中,比如ckpt,meta等文件这在部署时候不是很方便。官方提供了一种Freeze Graph的方式,用于将模型相关信息统统打包到一个*.pb文件当中。

         官方提供了相关工具freeze_graph,一般安装完TensorFlow后会自动添加到用户PATH相应的bin目录下,如果没有找到的话可以去TensorFlow源码tensorflow/python/tools/free_graph.py这个位置去找一下,或者直接通过命令行导入module的方式调用。

举例如下,如果有多个输出节点,用逗号隔开:


# 1.直接调用
freeze_graph --input_graph=/home/mnist-tf/graph.proto \
    --input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \
    --output_graph=/tmp/frozen_graph.pb \
    --output_node_names=fc2/add \
    --input_binary=True
 ##fc2/add为输出节点的名字,这个要看一下输出节点的名字是什么
# 2. 通过调用module的方式
python -m tensorflow.python.tools.freeze_graph \
    --input_graph=my_checkpoint_dir/graphdef.pb \
    --input_binary=true \
    --output_node_names=output \
    --input_checkpoint=my_checkpoint_dir \
    --output_graph=frozen.pb
##output是输出节点的名字

获得 

三、pb转onnx

使用以下命令即可生成

python3 -m tf2onnx.convert --input model.pb  --inputs input_img:0[2,384,512,3] --outputs sigmoid_logits:0 --output model.onnx
其中input_img为网络输入的名字,sigmoid_logits为网络输出的名字,[2,384,512,3]表示输入的维度

具体详细参数可以查看https://github.com/onnx/tensorflow-onnx

获取网络输入输出的名字代码如下:

import tensorflow as tf
import os
 
model_dir = '/home/cidi-gpu/disk/data/liyi/tensorflow2onnx/'
model_name = 'LRLane_res2net_highway_6mm.pb'
 
def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
 
create_graph()
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
print("input_name:"+tensor_name_list[0])
print("output_name:"+tensor_name_list[-1])

四、问题:

1、转成onnx后转trt的时候,出现Assertion failed: kernel_weights.shape.d[1] * ngroup == nchan报错。

  解决:出现这个报错是因为使用tf2onnx工具的版本过低,不支持一些层,之前使用的tf2onnx=1.4,就报错了,后来使用tensorflow=1.15.0, onnx=1.8.1, tf2onnx=1.8.4就没错。而且不同的tf2onnx版本生成的onnx图是不一样的,当转trt报错的时候,可以先尝试一下

           onnx是不是不一致,或许更改转onnx版本,能解决tf不成功地方问题。

  在查找解决办法的时候,也有网友提出以下的解决办法:I have changed that with view() with exact dimensions, for example unsqueeze(-1) changed to view(1,1,36,60). That solved the problem.

 

 

 

 

  • 4
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值