convert_variables_to_constants()

使用tf.train.Saver会保存运行TensorFlow程序所需要的全部信息,而在测试或离线预测时只需要知道如何由输入层经过前向传播计算得到输出层即可,不需要变量初始化、模型保存等辅助节点的信息。
将变量取值和计算图结构分成不同的文件存储也不方便,因此TensorFlow提供了convert_variables_to_constants函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个TensorFlow计算图可以统一存放在一个文件中。

1. 将计算图中的变量及其取值通过常量方式保存
1.1生成.pb文件
import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, tf.float32, [1]),name='v1')
v2 = tf.Variable(tf.constant(2.0, tf.float32, [1]),name='v2')
result = v1 + v2

with tf.Session() as sess:
    print(sess.run(tf.global_variables_initializer()))
    print(result.eval())
    # 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
    graph_def = tf.get_default_graph().as_graph_def()

    # 将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉(例如一些诸如变量初始化操作的系统运算)
    # 如果只关心程序中定义的某些运算时,和这些计算无关的节点就没有必要导出并保存了,在下面的一行代码中,
    # 最后一个参数['add']给出了需要保存的节点名称.add节点是上面定义的两个变量相加的操作.
    # 注意这里给出的计算节点的名称,所以没有后面的:0,:0表示的是该节点的第一个输出
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])

    # 将导出的模型存入文件
    with tf.gfile.GFile('./model/combined_model.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())
1.2. 读取.pb文件并直接获取节点结果
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = './model/combined_model.pb'
    # 读取保存的模型文件,并将文件解析成对应的GraphDef Protobuf Buffer
    with gfile.FastGFile(model_filename,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # 将graph_def中保存的图加载到当前的图中,return_elements=['add:0']给出了返回的张量名称
    # 在保存的时候给出的是计算节点的名称,所以为'add',在加载的时候给出的是张量的名称,所以是'add:0'
    result = tf.import_graph_def(graph_def, return_elements=['add:0'])
    print(sess.run(result))  # 输出: [array([3.], dtype=float32)]
2. 无需重新构造计算图,直接通过训练文件得到.pb文件

先训练,得到训练文件:

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0,tf.float32, [1]),name='v1')
v2 = tf.Variable(tf.constant(2.0,tf.float32, [1]),name='v2')

result = v1 + v2
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(result.eval())
    saver.save(sess,'./model/model.ckpt')

训练文件如下:
在这里插入图片描述
读取训练文件生成.pb文件:

import tensorflow as tf
from tensorflow.python.framework import graph_util

saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
graph_def = tf.get_default_graph().as_graph_def()

with tf.Session() as sess:
    saver.restore(sess, './model/model.ckpt')
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))  # 3.0
    with tf.gfile.GFile('./model/model.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())  # 得到文件:model.pb
3. 如果输入是placeholder,如何运行由.pb文件创建的计算图?

构建计算图,并生成.pb文件:

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.placeholder(tf.float32, shape=[], name='v1')
v2 = tf.placeholder(tf.float32, shape=[], name='v2')
result = v1 + v2

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result_val = sess.run(result, feed_dict={v1:1, v2:2})
    print(result_val)
    # 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
    graph_def = tf.get_default_graph().as_graph_def()

    # 将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉(例如一些诸如变量初始化操作的系统运算)
    # 如果只关心程序中定义的某些运算时,和这些计算无关的节点就没有必要导出并保存了,在下面的一行代码中,
    # 最后一个参数['add']给出了需要保存的节点名称.add节点是上面定义的两个变量相加的操作.
    # 注意这里给出的计算节点的名称,所以没有后面的:0,:0表示的是该节点的第一个输出
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])

    # 将导出的模型存入文件
    with tf.gfile.GFile('./model/combined_model.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())

由.pb文件构建计算图,并进行计算:

import tensorflow as tf
from tensorflow.python.platform import gfile

v1_ph = tf.placeholder(tf.float32,name='v1_ph')
v2_ph = tf.placeholder(tf.float32,name='v2_ph')

with tf.Session() as sess:
    model_filename = './model/combined_model.pb'
    # 读取保存的模型文件,并将文件解析成对应的GraphDef Protobuf Buffer
    with gfile.FastGFile(model_filename,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

        # 查看计算图上有哪些节点,用于确定网络的输入输出节点名
        node_names_list = [tensor.name for tensor in graph_def.node]
        for node_name in node_names_list:
            print(node_name)

    # 将graph_def中保存的图加载到当前的图中,return_elements=['add:0']给出了返回的张量名称
    # 在保存的时候给出的是计算节点的名称,所以为'add',在加载的时候给出的是张量的名称,所以是'add:0'
    result = tf.import_graph_def(graph_def, input_map={'v1:0':v1_ph, 'v2:0':v2_ph}, return_elements=['add:0'])
    print(sess.run(result, feed_dict={v1_ph:177, v2_ph:12}))  # 输出189.0
Reference
  1. 郑泽宇等.TensorFLow实战Google深度学习框架(第2版),电子工业出版社,2018.
  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值