tensorflow中freeze_graph_with_def保存的模型,tables无法初始化问题

在tensorflow自带导出pb模型的接口中,常常会封装graph_util.convert_variables_to_constants这个函数,但由于它会把sess.run(tf.tables_initializer())的操作给删除,不输出操作节点,导致在导入pb模型时会报错

FailedPreconditionError (see above for traceback): Table not initialized

 如freeze_graph.freeze_graph_with_def_protos其实就是在graph_util.convert_variables_to_constants基础上封装了一层,由于其内部又生成了一个session,如果你企图在此函数前把sess.run(tf.tables_initializer())添加到网络图中,在执行freeze_graph.freeze_graph_with_def_protos后就会变成空操作,仍然无法初始化tables

import os

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.ops.lookup_ops import HashTable, KeyValueTensorInitializer

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

OUTPUT_FOLDER = '/tmp'
OUTPUT_NAME = 'hash_table.pb'
OUTPUT_NAMES = ['graph/output', 'init_all_tables']


def build_graph():
    d = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
    init = KeyValueTensorInitializer(d.keys(), d.values())
    hash_table = HashTable(init, default_value=-1)
    data = tf.placeholder(tf.string, (None,), name='data')
    values = hash_table.lookup(data)
    output = tf.identity(values * 2, 'output')


def freeze_graph():
    with tf.Graph().as_default() as graph:
        with tf.name_scope('graph'):
            build_graph()

        with tf.Session(graph=graph) as sess:
            sess.run(tf.tables_initializer())
            print sess.run('graph/output:0', feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']})
            frozen_graph = convert_variables_to_constants(sess, sess.graph_def, OUTPUT_NAMES)
            tf.train.write_graph(frozen_graph, OUTPUT_FOLDER, OUTPUT_NAME, as_text=False)


def load_frozen_graph():
    with open(os.path.join(OUTPUT_FOLDER, OUTPUT_NAME), 'rb') as f:
        output_graph_def = tf.GraphDef()
        output_graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(output_graph_def, name='')
        with tf.Session(graph=graph) as sess:
            try:
                sess.run(graph.get_operation_by_name('init_all_tables'))
            except KeyError:
                pass
            print sess.run('graph/output:0', feed_dict={'graph/data:0': ['a', 'b', 'c', 'd', 'e']})


if __name__ == '__main__':
    freeze_graph()
    load_frozen_graph()

把init_all_tables当成一个输出节点进行输出,应该直接用graph_util.convert_variables_to_constants这个函数,不要在freeze_graph.freeze_graph_with_def_protos中再生成一个session对话才run

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值