TensorFlow查看pb模型节点信息

 显示pb模型节点信息:

import tensorflow as tf
import os
 
model_dir = './'
model_name = 'ocr.pb'

tf.reset_default_graph()

def create_graph():
    with tf.gfile.FastGFile(os.path.join(model_dir, model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
 
create_graph()
 
for node in tf.get_default_graph().as_graph_def().node:
    # print(node)
    print(node.op) # Placeholder
    print(node.name)
    print(node.attr)
    inputs = [str(in_name) for in_name in node.input]
    break

部分转载自:TensorFlow查看输入节点和输出节点名称_miao0967020148的专栏-CSDN博客

保存简单网络为pb文件

import tensorflow as tf
import numpy as np

data_shape = [1, 3, 32, 32]
filter_shape = [3, 3, 3, 64]
strides = [1, 1, 1, 1]
padding = "SAME"
dilations = [1, 1, 1, 1]
dtype = "float32"
data_format = "NCHW"

input = tf.placeholder(dtype=dtype, shape=data_shape)
filter = tf.placeholder(dtype=dtype, shape=filter_shape)

np.random.seed(0)
input_np = np.array(np.random.rand(*data_shape), dtype=dtype)
filter_np = np.array(np.random.rand(*filter_shape), dtype=dtype)

conv1_out = tf.nn.conv2d(input, filter, strides=strides, padding=padding,
                         data_format=data_format, dilations=dilations, name="conv2d_1")

def freeze_graph(output_node_names, output_graph_name):
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    pbtxt_model_name = output_graph_name[:-3]+".pbtxt"
    tf.io.write_graph(input_graph_def, './', pbtxt_model_name)

    with tf.Session() as sess:
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess=sess, input_graph_def=input_graph_def,
                                                                                  output_node_names=output_node_names.split(","))
        with tf.gfile.GFile(output_graph_name, "wb") as f:
            f.write(output_graph_def.SerializeToString())
 
with tf.Session() as session:
    result1 = session.run(conv1_out, feed_dict={input: input_np, filter: filter_np})
    print(result1)

    freeze_graph("conv2d_1", "conv2d_test.pb")

        # graph = tf.get_default_graph()
        # graph_def = graph.as_graph_def()
        graph_def = sess.graph_def
 

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,
            output_node_names=output_names)
 
        with tf.gfile.GFile(dump_graph_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())

pb pbtxt读取

graph_def = tf.compat.v1.GraphDef()
# pb
with tf.io.gfile.GFile(model_path, "rb") as f:
    graph_def.ParseFromString(f.read())
# pbtxt
with open(model_path, "r") as pf:
    text_format.Parse(pf.read(), graph_def)

tf2 frozen graph推理

参考:https://github.com/tensorflow/ngraph-bridge/blob/master/examples/infer_image.py

import time
import tensorflow as tf
import numpy as np
import os
import pdb
import sys

data_shape = (1, 512, 512, 3)
dtype = np.float32
model_file = "./model.pb"

test_data = np.random.random(data_shape).astype(dtype)

def load_graph(model_file):
    graph = tf.Graph()
    graph_def = tf.compat.v1.GraphDef()

    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def)

    return graph

graph = load_graph(model_file)

input_name = "import/" + "input_1"
output_name = "import/" + "output_1"
input_operation = graph.get_operation_by_name(input_name)
output_operation = graph.get_operation_by_name(output_name)

print('start infer ............................')
with tf.compat.v1.Session(graph=graph) as sess:
    # Warmup
    results = sess.run(output_operation.outputs[0],
                       {input_operation.outputs[0]: test_data})

    results = np.squeeze(results)
    print(results)

tensorflow实现将ckpt转pb文件_pan_jinquan的博客-CSDN博客_ckpt转pb

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Luchang-Li

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值