import argparse
import os

import tensorflow as tf
from tensorflow.python.framework import graph_util
tf.reset_default_graph() # Reset the default computation graph before importing the model
def network_structure(args):
    """Load a frozen GraphDef (.pb) file and print its structure.

    Prints the names of the nodes stored in the GraphDef, then every
    operation of the default graph after import together with its output
    tensors, and writes the graph to a TensorBoard log directory named
    ``log_graph_<model file name>``.

    Args:
        args: an ``argparse.Namespace``-like object with attributes
            ``model``: path to the frozen ``.pb`` file; ``"model.pb"`` is
                used when it is ``None`` or empty.
            ``n``: if truthy, stop after printing this many operations.
    """
    # The original code overwrote args.model with "model.pb" unconditionally,
    # which silently ignored --model; only fall back when nothing was given.
    model_path = args.model or "model.pb"
    with tf.Session() as sess:
        # No variables exist yet, so this only adds a no-op named "init" to
        # the default graph; it is the extra entry in the operation listing.
        tf.global_variables_initializer().run()
        output_graph_def = tf.GraphDef()
        # Get the default graph; the model is imported into it below.
        graph = tf.get_default_graph()
        with open(model_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")

        # Number of nodes (operations) stored in the GraphDef file.
        print("%d ops in the graph." % len(output_graph_def.node))
        node_names = [node.name for node in output_graph_def.node]
        print(node_names)
        print('=======================================================')

        # Write the graph to a log directory so it can be visualized in
        # TensorBoard. Use only the file name so a path such as
        # "dir/model.pb" does not create nested directories.
        summary_writer = tf.summary.FileWriter(
            'log_graph_' + os.path.basename(model_path), graph)
        # Make sure the graph event is flushed to disk and the file released.
        summary_writer.close()

        operations = graph.get_operations()
        # NOTE: this counts operations (not tensors); the message text is kept
        # unchanged so existing output stays comparable.
        print("%d tensors in the graph." % len(operations))
        for cnt, op in enumerate(operations, start=1):
            # Print the operation name and its output tensors.
            print(op.name, op.values())
            if args.n and cnt == args.n:
                break
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Make the model the script falls back to explicit (shown in --help).
    parser.add_argument('--model', type=str, default="model.pb",
                        help="model name to look")
    # --n is useful when the graph has too many tensor names to read at once.
    parser.add_argument('--n', type=int,
                        help='the number of first several tensor name to look')
    args = parser.parse_args()
    network_structure(args)
输出结果如下:
32 ops in the graph.
['Placeholder', 'Variable', 'Variable/read', 'Variable_1', 'Variable_1/read', 'Conv2D', 'BiasAdd', 'Relu', 'MaxPool', 'Variable_2', 'Variable_2/read', 'Variable_3', 'Variable_3/read', 'Conv2D_1', 'BiasAdd_1', 'Relu_1', 'MaxPool_1', 'Reshape/shape', 'Reshape', 'Variable_4', 'Variable_4/read', 'Variable_5', 'Variable_5/read', 'MatMul', 'add', 'Relu_2', 'Variable_6', 'Variable_6/read', 'Variable_7', 'Variable_7/read', 'MatMul_1', 'output']
=======================================================
33 tensors in the graph.
init ()
Placeholder (<tf.Tensor 'Placeholder:0' shape=(1, 28, 28, 1) dtype=float32>,)
Variable (<tf.Tensor 'Variable:0' shape=(5, 5, 1, 32) dtype=float32>,)
Variable/read (<tf.Tensor 'Variable/read:0' shape=(5, 5, 1, 32) dtype=float32>,)
Variable_1 (<tf.Tensor 'Variable_1:0' shape=(32,) dtype=float32>,)
Variable_1/read (<tf.Tensor 'Variable_1/read:0' shape=(32,) dtype=float32>,)
Conv2D (<tf.Tensor 'Conv2D:0' shape=(1, 28, 28, 32) dtype=float32>,)
BiasAdd (<tf.Tensor 'BiasAdd:0' shape=(1, 28, 28, 32) dtype=float32>,)
Relu (<tf.Tensor 'Relu:0' shape=(1, 28, 28, 32) dtype=float32>,)
MaxPool (<tf.Tensor 'MaxPool:0' shape=(1, 14, 14, 32) dtype=float32>,)
Variable_2 (<tf.Tensor 'Variable_2:0' shape=(5, 5, 32, 64) dtype=float32>,)
Variable_2/read (<tf.Tensor 'Variable_2/read:0' shape=(5, 5, 32, 64) dtype=float32>,)
Variable_3 (<tf.Tensor 'Variable_3:0' shape=(64,) dtype=float32>,)
Variable_3/read (<tf.Tensor 'Variable_3/read:0' shape=(64,) dtype=float32>,)
Conv2D_1 (<tf.Tensor 'Conv2D_1:0' shape=(1, 14, 14, 64) dtype=float32>,)
BiasAdd_1 (<tf.Tensor 'BiasAdd_1:0' shape=(1, 14, 14, 64) dtype=float32>,)
Relu_1 (<tf.Tensor 'Relu_1:0' shape=(1, 14, 14, 64) dtype=float32>,)
MaxPool_1 (<tf.Tensor 'MaxPool_1:0' shape=(1, 7, 7, 64) dtype=float32>,)
Reshape/shape (<tf.Tensor 'Reshape/shape:0' shape=(2,) dtype=int32>,)
Reshape (<tf.Tensor 'Reshape:0' shape=(1, 3136) dtype=float32>,)
Variable_4 (<tf.Tensor 'Variable_4:0' shape=(3136, 512) dtype=float32>,)
Variable_4/read (<tf.Tensor 'Variable_4/read:0' shape=(3136, 512) dtype=float32>,)
Variable_5 (<tf.Tensor 'Variable_5:0' shape=(512,) dtype=float32>,)
Variable_5/read (<tf.Tensor 'Variable_5/read:0' shape=(512,) dtype=float32>,)
MatMul (<tf.Tensor 'MatMul:0' shape=(1, 512) dtype=float32>,)
add (<tf.Tensor 'add:0' shape=(1, 512) dtype=float32>,)
Relu_2 (<tf.Tensor 'Relu_2:0' shape=(1, 512) dtype=float32>,)
Variable_6 (<tf.Tensor 'Variable_6:0' shape=(512, 10) dtype=float32>,)
Variable_6/read (<tf.Tensor 'Variable_6/read:0' shape=(512, 10) dtype=float32>,)
Variable_7 (<tf.Tensor 'Variable_7:0' shape=(10,) dtype=float32>,)
Variable_7/read (<tf.Tensor 'Variable_7/read:0' shape=(10,) dtype=float32>,)
MatMul_1 (<tf.Tensor 'MatMul_1:0' shape=(1, 10) dtype=float32>,)
output (<tf.Tensor 'output:0' shape=(1, 10) dtype=float32>,)
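其中双横线(=====)以上是 GraphDef 中各节点(operation)的名字;双横线以下是导入后 graph.get_operations() 返回的每个 operation 的名字及其输出 tensor(op.values())。下方比上方多出的 `init ()` 是代码中 tf.global_variables_initializer() 创建的空操作(没有输出 tensor),所以下方数量是 33 而不是 32。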
知道了模型中 tensor 的名字(即 operation 名字加上输出序号,例如 `Placeholder:0`、`output:0`),我们就可以把具体的输入数据通过 feed_dict 喂给相应的 tensor,然后用该模型进行计算了:
# Feed an input array into the imported graph and take the predicted class.
# NOTE(review): assumes `sess` is a tf.Session whose graph already contains
# the imported model, and `testPicArr` is an array matching the 'Placeholder'
# shape (1, 28, 28, 1) shown in the listing above; confirm in the caller.
input_x = sess.graph.get_tensor_by_name('Placeholder:0')
output = sess.graph.get_tensor_by_name('output:0')
# tf.arg_max is a deprecated alias; tf.argmax is the supported API.
pred_op = tf.argmax(output, 1)
preValue = sess.run(pred_op, feed_dict={input_x: testPicArr})