tensorflow之pb文件保存与载入

pb是protocol(协议) buffer(缓冲)的缩写。TensorFlow训练模型后存成的pb文件,是一种表示模型(神经网络)结构的二进制文件,将图中的变量保存成为常量,便于调用,一般无法将pb文件映射成源代码。pb文件的能够保存tensorflow计算图中的操作节点以及对应的各张量,方便我们日后直接调用之前已经训练好的计算图。

注:pb文件可以在训练时直接保存,也可以用.ckpt文件转化为.pb文件。

准备所需的文件:

1、训练完成的.index,.data,.mate文件。

2、编写.ckpt文件转.pb文件的代码。

output_node_names = ["Input/X_placeholder", "Inference/output"]#指定输入输出节点名
def freeze_pb(pb_file,ckpt_path):

    with tf.name_scope('Input'):
        input_data = tf.placeholder(dtype=tf.float32,shape=[None, 784], name='X_placeholder')

    with tf.name_scope('Inference'):
        # batch:20 输入:784,通道:1,输出:10
        W = tf.Variable(initial_value=tf.random_normal(shape=[784,10], stddev=0.01), name='Weights')
        b = tf.Variable(initial_value=tf.zeros(shape=[10]), name='bias')
        print(W)
        logits = tf.matmul(input_data, W) + b
        pred = tf.nn.softmax(logits=logits,name='output')


    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    saver = tf.train.Saver()
    saver.restore(sess, ckpt_path)

    converted_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                       input_graph_def=sess.graph.as_graph_def(),
                                                                       output_node_names=output_node_names)

    with tf.gfile.GFile(pb_file, "wb") as f:
        f.write(converted_graph_def.SerializeToString())

注:

1、在将ckpt文件转为pb文件的时候,一定要将输入节点名称和输出节点名称与设计的网络中的输入节点名称和输出节点名称对应起来,否则会报错。例如这里将

Input/X_placeholder改为Input/X_placeholder_e时会有如下的报错。

AssertionError: Input/X_placeholdere_e is not in graph.

output_node_names = ["Input/X_placeholder", "Inference/output"]指定输入输出。

2、tf.graph_util.convert_variables_to_constants中的output_node_names变量来指定保存的节点名称而不是张量的名称,“Input/X_placeholder:0”是张量的名称而"Input/X_placeholder"表示的是节点的名称。在固化pb文件的时候用节点名称,在调用pb文件是使用张量名称。

执行上述代码后会在指定的目录下生成pb文件。

3、编写调用pb文件代码。

return_enement = ["Input/X_placeholder:0", "Inference/output:0"]#指定输入输出张量名称
def Load_PbFile(pb_file,image):
    imagedata=np.array(image).reshape([1,784])
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True


    with tf.gfile.FastGFile(pb_file, 'rb') as f:
        frozen_graph_def = tf.GraphDef()
        frozen_graph_def.ParseFromString(f.read())
        elements = tf.import_graph_def(frozen_graph_def,return_elements=return_enement)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        prob=sess.run([elements[1]],feed_dict={elements[0]:imagedata})

        tolist=list(prob[0][0])
        classes = tolist.index(max(tolist))
        print(classes)

这里return_enement = ["Input/X_placeholder:0", "Inference/output:0"]指定的是张量名称,区别于output_node_names = ["Input/X_placeholder", "Inference/output"],否则在运行时会出现TypeError: Cannot interpret feed_dict key as Tensor: Can not convert a Operation into a Tensor.的错误。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值