1 测试代码:
$ cat export_nodename.py
#!/usr/bin/env python
#coding:utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf
# Location of the frozen model to inspect. The directory is a relative path,
# so it is resolved against the current working directory when opened.
model_dir, model_name = 'work/CNN/CNN2/training', 'dnn.pb'
def create_graph(path=None):
    """Parse a frozen GraphDef file and import it into the default graph.

    The file is read as a binary-serialized ``GraphDef`` protobuf. Its nodes
    are imported with ``name=''`` so they keep their original names (no
    ``import/`` prefix), which is what a caller listing node names expects.

    Args:
        path: Optional path to the ``.pb`` file. When ``None`` (the default)
            the path is built from the module-level ``model_dir`` and
            ``model_name``, so a plain ``create_graph()`` call behaves as it
            always has.

    Returns:
        The parsed ``tf.GraphDef``. Returning it lets callers inspect the
        nodes directly; existing callers that ignore the result are
        unaffected.

    Raises:
        Errors from ``tf.gfile.GFile`` if the file cannot be opened, or from
        ``ParseFromString`` if its contents are not a valid GraphDef
        (exact exception types depend on the TF/protobuf version -- verify).
    """
    if path is None:
        path = os.path.join(model_dir, model_name)

    # Binary mode: the .pb file is a serialized protobuf, not text.
    with tf.gfile.GFile(path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # The bytes are fully parsed at this point, so the file is closed before
    # the (potentially large) graph is imported.
    tf.import_graph_def(graph_def, name='')
    return graph_def
# Import the model into the default graph, then print every node name in it,
# one per line.
create_graph()

# NOTE: despite the variable name, these are GraphDef *node* (operation)
# names, not tensor names; TensorFlow addresses a node's output tensors as
# "<node name>:<output index>" (e.g. "<node name>:0").
tensor_name_list = [
    node_def.name
    for node_def in tf.get_default_graph().as_graph_def().node
]

for name in tensor_name_list:
    print(name)
2 运行代码查看结果: