TensorFlow 自带的 pb 模型导出接口通常都封装了 graph_util.convert_variables_to_constants 这个函数。该函数只会保留从指定输出节点(output_node_names)可达的子图,而 tf.tables_initializer() 创建的表初始化操作(默认名为 init_all_tables)并不在这些输出节点的依赖路径上,因此会被裁剪掉,不会写入导出的图中。这样在导入 pb 模型后执行查表时就会报错:
FailedPreconditionError (see above for traceback): Table not initialized
例如 freeze_graph.freeze_graph_with_def_protos 其实就是在 graph_util.convert_variables_to_constants 的基础上又封装了一层。由于它在内部又新建了一个 session,即使你在调用此函数之前就把 sess.run(tf.tables_initializer()) 加到了网络图中,执行 freeze_graph.freeze_graph_with_def_protos 之后这一步也会变成空操作,仍然无法初始化 tables。
import os
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.ops.lookup_ops import HashTable, KeyValueTensorInitializer
# Reduce TensorFlow's C++ log output ('3' is the most restrictive level).
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

# Location of the frozen GraphDef written by freeze_graph() and read back by
# load_frozen_graph().
OUTPUT_FOLDER = '/tmp'
OUTPUT_NAME = 'hash_table.pb'

# Op names passed to convert_variables_to_constants. 'init_all_tables' is
# listed so the table-initializer op survives pruning of the frozen graph.
OUTPUT_NAMES = [
    'graph/output',
    'init_all_tables',
]
def build_graph(mapping=None, default_value=-1):
    """Build a string-keyed HashTable lookup graph in the current default graph.

    Creates a string placeholder named ``data``, looks every element up in a
    HashTable, and emits the doubled result through an identity op named
    ``output``. Names are relative to the active name scope, so when called
    under ``tf.name_scope('graph')`` the feed/fetch names are
    ``graph/data:0`` and ``graph/output:0``. The table must be initialized
    (e.g. by running ``tf.tables_initializer()``) before the output is
    evaluated.

    Args:
        mapping: dict of table keys to values. Defaults to
            ``{'a': 1, 'b': 2, 'c': 3, 'd': 4}``.
        default_value: value returned for keys missing from ``mapping``.

    Returns:
        The ``output`` tensor.

    Raises:
        ValueError: if ``mapping`` is empty (the key/value dtypes could not
            be inferred).
    """
    if mapping is None:
        mapping = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    if not mapping:
        raise ValueError('mapping must not be empty')
    # Pair keys and values explicitly via items(), and materialize them as
    # lists: Python 3 dict views do not convert cleanly to tensors.
    pairs = list(mapping.items())
    keys = [key for key, _ in pairs]
    vals = [val for _, val in pairs]
    init = KeyValueTensorInitializer(keys, vals)
    hash_table = HashTable(init, default_value=default_value)
    data = tf.placeholder(tf.string, (None,), name='data')
    looked_up = hash_table.lookup(data)
    output = tf.identity(looked_up * 2, 'output')
    return output
def freeze_graph(output_folder=OUTPUT_FOLDER, output_name=OUTPUT_NAME):
    """Build the lookup graph, initialize its tables and write a frozen .pb.

    convert_variables_to_constants keeps only the ops needed by the listed
    output nodes. The table-initializer op is therefore created with the
    explicit name 'init_all_tables', which OUTPUT_NAMES also lists, so it is
    kept in the exported GraphDef and the loader can run it before lookups.

    Args:
        output_folder: directory the frozen graph is written to.
        output_name: file name of the frozen graph.
    """
    with tf.Graph().as_default() as graph:
        with tf.name_scope('graph'):
            build_graph()
        # Created outside the 'graph' name scope and named explicitly so the
        # op name is exactly 'init_all_tables', matching OUTPUT_NAMES rather
        # than relying on the function's default name.
        init_tables = tf.tables_initializer(name='init_all_tables')
        with tf.Session(graph=graph) as sess:
            sess.run(init_tables)
            # Sanity-check the live graph before freezing ('e' is not a key
            # of the default table, so it maps to the default value).
            print(sess.run('graph/output:0',
                           feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']}))
            frozen_graph = convert_variables_to_constants(
                sess, sess.graph_def, OUTPUT_NAMES)
    tf.train.write_graph(frozen_graph, output_folder, output_name, as_text=False)
def load_frozen_graph(output_folder=OUTPUT_FOLDER, output_name=OUTPUT_NAME):
    """Load the frozen .pb, initialize its tables if present, and run a lookup.

    Args:
        output_folder: directory containing the frozen graph.
        output_name: file name of the frozen graph.

    Returns:
        The value fetched from 'graph/output:0' for the sample batch
        ['a', 'b', 'c', 'd', 'e'] (it is also printed).
    """
    with open(os.path.join(output_folder, output_name), 'rb') as f:
        output_graph_def = tf.GraphDef()
        output_graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        # name='' keeps the original op names (no 'import/' prefix), so the
        # 'graph/...' feed/fetch names below resolve directly.
        tf.import_graph_def(output_graph_def, name='')
        # Only the lookup belongs in the try: a graph frozen without a table
        # initializer is tolerated, but errors from running ops are not hidden.
        try:
            init_tables = graph.get_operation_by_name('init_all_tables')
        except KeyError:
            init_tables = None
        with tf.Session(graph=graph) as sess:
            if init_tables is not None:
                sess.run(init_tables)
            result = sess.run('graph/output:0',
                              feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']})
            print(result)
    return result
if __name__ == '__main__':
    # Export first, then reload the exported file to confirm the tables can
    # be initialized from the frozen graph alone.
    for step in (freeze_graph, load_frozen_graph):
        step()
解决方法:把 init_all_tables 当作一个输出节点(加入 output_node_names)一并导出,并且直接调用 graph_util.convert_variables_to_constants,而不要使用会在内部另建一个 session 再运行的 freeze_graph.freeze_graph_with_def_protos。加载 pb 模型后,先运行一次 init_all_tables 操作,再执行查表推理即可。