显示pb模型中节点的详细信息

79 篇文章 1 订阅
51 篇文章 6 订阅
import tensorflow as tf
from tensorflow.python.framework import graph_util
import argparse

tf.reset_default_graph()  # 重置计算图

def network_structure(args):
    args.model="model.pb"
    model_path = args.model
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        output_graph_def = tf.GraphDef()
        # 获得默认的图
        graph = tf.get_default_graph()
        with open(model_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(output_graph_def, name="")
            # 得到当前图有几个操作节点
            print("%d ops in the graph." % len(output_graph_def.node))
            op_name = [tensor.name for tensor in output_graph_def.node]
            print(op_name)
            print('=======================================================')
            # 在log_graph文件夹下生产日志文件,可以在tensorboard中可视化模型
            summaryWriter = tf.summary.FileWriter('log_graph_'+args.model, graph)
            cnt = 0
            print("%d tensors in the graph." % len(graph.get_operations()))
            for tensor in graph.get_operations():
                # print出tensor的name和值
       
                print(tensor.name, tensor.values())
                cnt += 1
                if args.n:
                    if cnt == args.n:
                        break
                        
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help="model name to look")
    parser.add_argument('--n', type=int, help='the number of first several tensor name to look') # 当tensor_name过多
    args = parser.parse_args()
    network_structure(args)

输出结果如下:

32 ops in the graph.
['Placeholder', 'Variable', 'Variable/read', 'Variable_1', 'Variable_1/read', 'Conv2D', 'BiasAdd', 'Relu', 'MaxPool', 'Variable_2', 'Variable_2/read', 'Variable_3', 'Variable_3/read', 'Conv2D_1', 'BiasAdd_1', 'Relu_1', 'MaxPool_1', 'Reshape/shape', 'Reshape', 'Variable_4', 'Variable_4/read', 'Variable_5', 'Variable_5/read', 'MatMul', 'add', 'Relu_2', 'Variable_6', 'Variable_6/read', 'Variable_7', 'Variable_7/read', 'MatMul_1', 'output']
=======================================================
33 tensors in the graph.
init ()
Placeholder (<tf.Tensor 'Placeholder:0' shape=(1, 28, 28, 1) dtype=float32>,)
Variable (<tf.Tensor 'Variable:0' shape=(5, 5, 1, 32) dtype=float32>,)
Variable/read (<tf.Tensor 'Variable/read:0' shape=(5, 5, 1, 32) dtype=float32>,)
Variable_1 (<tf.Tensor 'Variable_1:0' shape=(32,) dtype=float32>,)
Variable_1/read (<tf.Tensor 'Variable_1/read:0' shape=(32,) dtype=float32>,)
Conv2D (<tf.Tensor 'Conv2D:0' shape=(1, 28, 28, 32) dtype=float32>,)
BiasAdd (<tf.Tensor 'BiasAdd:0' shape=(1, 28, 28, 32) dtype=float32>,)
Relu (<tf.Tensor 'Relu:0' shape=(1, 28, 28, 32) dtype=float32>,)
MaxPool (<tf.Tensor 'MaxPool:0' shape=(1, 14, 14, 32) dtype=float32>,)
Variable_2 (<tf.Tensor 'Variable_2:0' shape=(5, 5, 32, 64) dtype=float32>,)
Variable_2/read (<tf.Tensor 'Variable_2/read:0' shape=(5, 5, 32, 64) dtype=float32>,)
Variable_3 (<tf.Tensor 'Variable_3:0' shape=(64,) dtype=float32>,)
Variable_3/read (<tf.Tensor 'Variable_3/read:0' shape=(64,) dtype=float32>,)
Conv2D_1 (<tf.Tensor 'Conv2D_1:0' shape=(1, 14, 14, 64) dtype=float32>,)
BiasAdd_1 (<tf.Tensor 'BiasAdd_1:0' shape=(1, 14, 14, 64) dtype=float32>,)
Relu_1 (<tf.Tensor 'Relu_1:0' shape=(1, 14, 14, 64) dtype=float32>,)
MaxPool_1 (<tf.Tensor 'MaxPool_1:0' shape=(1, 7, 7, 64) dtype=float32>,)
Reshape/shape (<tf.Tensor 'Reshape/shape:0' shape=(2,) dtype=int32>,)
Reshape (<tf.Tensor 'Reshape:0' shape=(1, 3136) dtype=float32>,)
Variable_4 (<tf.Tensor 'Variable_4:0' shape=(3136, 512) dtype=float32>,)
Variable_4/read (<tf.Tensor 'Variable_4/read:0' shape=(3136, 512) dtype=float32>,)
Variable_5 (<tf.Tensor 'Variable_5:0' shape=(512,) dtype=float32>,)
Variable_5/read (<tf.Tensor 'Variable_5/read:0' shape=(512,) dtype=float32>,)
MatMul (<tf.Tensor 'MatMul:0' shape=(1, 512) dtype=float32>,)
add (<tf.Tensor 'add:0' shape=(1, 512) dtype=float32>,)
Relu_2 (<tf.Tensor 'Relu_2:0' shape=(1, 512) dtype=float32>,)
Variable_6 (<tf.Tensor 'Variable_6:0' shape=(512, 10) dtype=float32>,)
Variable_6/read (<tf.Tensor 'Variable_6/read:0' shape=(512, 10) dtype=float32>,)
Variable_7 (<tf.Tensor 'Variable_7:0' shape=(10,) dtype=float32>,)
Variable_7/read (<tf.Tensor 'Variable_7/read:0' shape=(10,) dtype=float32>,)
MatMul_1 (<tf.Tensor 'MatMul_1:0' shape=(1, 10) dtype=float32>,)
output (<tf.Tensor 'output:0' shape=(1, 10) dtype=float32>,)

其中双虚线以上的是operation的名字,虚线以下是tensor的名字。
知道了模型中tensor的名字,我们就可以将具体矩阵赋值给相应tensor,然后通过该模型进行计算了。

input_x = sess.graph.get_tensor_by_name('Placeholder:0')
output= sess.graph.get_tensor_by_name('output:0')
preValue=tf.arg_max(output,1)
preValue = sess.run(preValue, feed_dict={input_x: testPicArr})
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值