You can print tensor names directly from code. For example:
1. Inspect the nodes in a checkpoint with the following code:
import os
from tensorflow.python import pywrap_tensorflow

# Path to the checkpoint to inspect
checkpoint_path = os.path.join("checkpoint-00454721")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key))  # uncomment to also print the values
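The loop above prints names only. A minimal sketch that also reports each variable's shape and element count might look like the following. It assumes `tf.train.load_checkpoint` is available in your TensorFlow build; `summarize_variables` and `summarize_checkpoint` are hypothetical helper names, not part of TensorFlow.

```python
from functools import reduce
from operator import mul


def summarize_variables(var_to_shape_map):
    """Return one line per variable, sorted by name: name, shape, element count."""
    lines = []
    for name in sorted(var_to_shape_map):
        shape = list(var_to_shape_map[name])
        # A scalar has an empty shape and holds exactly one element.
        count = reduce(mul, shape, 1)
        lines.append("{}  shape={}  elements={}".format(name, shape, count))
    return lines


def summarize_checkpoint(checkpoint_path):
    """Read a checkpoint and summarize its variables (hypothetical helper)."""
    # Imported lazily so that summarize_variables works without TensorFlow.
    import tensorflow as tf

    reader = tf.train.load_checkpoint(checkpoint_path)
    return summarize_variables(reader.get_variable_to_shape_map())


# Usage (requires TensorFlow and the checkpoint file):
#   for line in summarize_checkpoint("checkpoint-00454721"):
#       print(line)
```

Sorting by name keeps variables from the same scope together, which makes a long list easier to scan.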
2. Inspect the nodes in a checkpoint with the TensorFlow tool, using the following command:
inspect_checkpoint.py --file_name=checkpoint-00454721
The following is reposted from: https://www.cnblogs.com/bonelee/p/8462578.html
Inspect the nodes in a TensorFlow pb model file:
import tensorflow as tf

with tf.Session() as sess:
    with open('./quantized_model.pb', 'rb') as f:
        # Parse the serialized GraphDef and print it in text form
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        print(graph_def)
Output:
# ...
node {
  name: "FullyConnected/BiasAdd"
  op: "BiasAdd"
  input: "FullyConnected/MatMul"
  input: "FullyConnected/b/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "data_format"
    value {
      s: "NHWC"
    }
  }
}
node {
  name: "FullyConnected/Softmax"
  op: "Softmax"
  input: "FullyConnected/BiasAdd"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}
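Printing the whole `graph_def` gets verbose for large models. A hedged sketch that lists only each node's name, op, and inputs follows. It relies on the entries of `graph_def.node` exposing `name`, `op`, and `input`, the same fields that appear in the dump. `node_summary` and `load_graph_def` are hypothetical helper names, and `tf.compat.v1.GraphDef` is assumed to exist in your TensorFlow version. The stand-in nodes are built from the two nodes in the dump so the formatter can be tried without TensorFlow.

```python
from types import SimpleNamespace


def node_summary(nodes):
    """Return 'name (op) <- inputs' for each node-like object."""
    lines = []
    for node in nodes:
        inputs = ", ".join(node.input)
        suffix = " <- " + inputs if inputs else ""
        lines.append("{} ({}){}".format(node.name, node.op, suffix))
    return lines


def load_graph_def(pb_path):
    """Parse a binary .pb file into a GraphDef (hypothetical helper)."""
    # Imported lazily so that node_summary works without TensorFlow.
    import tensorflow as tf

    graph_def = tf.compat.v1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


# Stand-in objects mirroring the two nodes in the printed output.
demo_nodes = [
    SimpleNamespace(name="FullyConnected/BiasAdd", op="BiasAdd",
                    input=["FullyConnected/MatMul", "FullyConnected/b/read"]),
    SimpleNamespace(name="FullyConnected/Softmax", op="Softmax",
                    input=["FullyConnected/BiasAdd"]),
]
summary = node_summary(demo_nodes)
print("\n".join(summary))
```

With a real model, `node_summary(load_graph_def('./quantized_model.pb').node)` should give the same kind of listing.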
References: https://tang.su/2017/01/export-TensorFlow-network/
https://github.com/tensorflow/tensorflow/issues/15689
Some core code:
import tensorflow as tf

with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        print(graph_def)
        # Import the graph and fetch the tensor named 'out:0'
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))
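The `'out:0'` passed to `return_elements` is a tensor name: the node name `out`, a colon, and the index of that node's output. A bare node name without the `:index` part refers to the operation itself rather than to one of its output tensors. A small sketch of that naming convention, with `split_tensor_name` as a hypothetical helper (not a TensorFlow function):

```python
def split_tensor_name(name):
    """Split 'node:index' into (node_name, output_index).

    A bare node name has no output index, so None is returned in its place.
    """
    node_name, sep, index = name.rpartition(":")
    if not sep:
        # No colon at all: the whole string is an operation name.
        return name, None
    return node_name, int(index)


print(split_tensor_name("out:0"))
print(split_tensor_name("out"))
```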
This is part of my TensorFlow frozen graph. I have named the input and output nodes.
>>> g.ParseFromString(open('frozen_graph.pb','rb').read())
>>> g
node {
  name: "input"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 68
        }
      }
    }
  }
}
...
node {
  name: "output"
  op: "Softmax"
  input: "add"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
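When the node names are not known in advance, they can often be recovered from the graph itself: inputs are usually `Placeholder` nodes, and outputs are nodes that no other node consumes. The sketch below is a heuristic under that assumption, not a guaranteed method. `find_io_nodes` is a hypothetical helper, and the stand-in nodes only imitate the excerpt: the wiring of `add` is made up, because the excerpt elides that part of the graph.

```python
from types import SimpleNamespace


def find_io_nodes(nodes):
    """Guess input and output node names from node-like objects.

    Inputs: nodes whose op is 'Placeholder'.
    Outputs: non-placeholder nodes that never appear as another node's input.
    """
    nodes = list(nodes)
    consumed = set()
    for node in nodes:
        for ref in node.input:
            # References may look like 'name', 'name:1' or '^name' (control edge).
            consumed.add(ref.lstrip("^").split(":")[0])
    inputs = [n.name for n in nodes if n.op == "Placeholder"]
    outputs = [n.name for n in nodes
               if n.op != "Placeholder" and n.name not in consumed]
    return inputs, outputs


# Stand-in nodes; the input of 'add' is invented for illustration only.
demo = [
    SimpleNamespace(name="input", op="Placeholder", input=[]),
    SimpleNamespace(name="add", op="Add", input=["input"]),
    SimpleNamespace(name="output", op="Softmax", input=["add"]),
]
inputs, outputs = find_io_nodes(demo)
print(inputs, outputs)
```

With a parsed graph, pass `g.node` instead of the stand-in list. A graph may contain unconsumed helper nodes as well, so treat the result as a list of candidates.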
I ran this model with the following code
(CELL is the name of the directory where my file is located):
final String MODEL_FILE = "file:///android_asset/" + CELL + "/optimized_graph.pb";
final String INPUT_NODE = "input";
final String OUTPUT_NODE = "output";
final int[] INPUT_SIZE = {1, 68};
float[] RESULT = new float[8];
inferenceInterface = new TensorFlowInferenceInterface();
inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
inferenceInterface.fillNodeFloat(INPUT_NODE, INPUT_SIZE, input);
and finally
inferenceInterface.readNodeFloat(OUTPUT_NODE,RESULT);