使用tf.train.Saver
会保存运行TensorFlow程序所需要的全部信息,而在测试或离线预测时只需要知道如何由输入层经过前向传播计算得到输出层即可,不需要变量初始化、模型保存等辅助节点的信息。
将变量取值和计算图结构分成不同的文件存储也不方便,因此TensorFlow提供了convert_variables_to_constants
函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个TensorFlow计算图可以统一存放在一个文件中。
1. 将计算图中的变量及其取值通过常量方式保存
1.1生成.pb文件
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.0, tf.float32, [1]),name='v1')
v2 = tf.Variable(tf.constant(2.0, tf.float32, [1]),name='v2')
result = v1 + v2
with tf.Session() as sess:
print(sess.run(tf.global_variables_initializer()))
print(result.eval())
# 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
graph_def = tf.get_default_graph().as_graph_def()
# 将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉(例如一些诸如变量初始化操作的系统运算)
# 如果只关心程序中定义的某些运算时,和这些计算无关的节点就没有必要导出并保存了,在下面的一行代码中,
# 最后一个参数['add']给出了需要保存的节点名称.add节点是上面定义的两个变量相加的操作.
# 注意这里给出的计算节点的名称,所以没有后面的:0,:0表示的是该节点的第一个输出
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
# 将导出的模型存入文件
with tf.gfile.GFile('./model/combined_model.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())
1.2. 读取.pb文件并直接获取节点结果
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename = './model/combined_model.pb'
# 读取保存的模型文件,并将文件解析成对应的GraphDef Protobuf Buffer
with gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def中保存的图加载到当前的图中,return_elements=['add:0']给出了返回的张量名称
# 在保存的时候给出的是计算节点的名称,所以为'add',在加载的时候给出的是张量的名称,所以是'add:0'
result = tf.import_graph_def(graph_def, return_elements=['add:0'])
print(sess.run(result)) # 输出: [array([3.], dtype=float32)]
2. 无需重新构造计算图,直接通过训练文件得到.pb文件
先训练,得到训练文件:
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0,tf.float32, [1]),name='v1')
v2 = tf.Variable(tf.constant(2.0,tf.float32, [1]),name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(result.eval())
saver.save(sess,'./model/model.ckpt')
训练文件如下:
读取训练文件生成.pb文件:
import tensorflow as tf
from tensorflow.python.framework import graph_util
saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
graph_def = tf.get_default_graph().as_graph_def()
with tf.Session() as sess:
saver.restore(sess, './model/model.ckpt')
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0'))) # 3.0
with tf.gfile.GFile('./model/model.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString()) # 得到文件:model.pb
3. 如果输入是placeholder,如何运行由.pb文件创建的计算图?
构建计算图,并生成.pb文件:
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1 = tf.placeholder(tf.float32, shape=[], name='v1')
v2 = tf.placeholder(tf.float32, shape=[], name='v2')
result = v1 + v2
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
result_val = sess.run(result, feed_dict={v1:1, v2:2})
print(result_val)
# 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
graph_def = tf.get_default_graph().as_graph_def()
# 将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉(例如一些诸如变量初始化操作的系统运算)
# 如果只关心程序中定义的某些运算时,和这些计算无关的节点就没有必要导出并保存了,在下面的一行代码中,
# 最后一个参数['add']给出了需要保存的节点名称.add节点是上面定义的两个变量相加的操作.
# 注意这里给出的计算节点的名称,所以没有后面的:0,:0表示的是该节点的第一个输出
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
# 将导出的模型存入文件
with tf.gfile.GFile('./model/combined_model.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())
由.pb文件构建计算图,并进行计算:
import tensorflow as tf
from tensorflow.python.platform import gfile
v1_ph = tf.placeholder(tf.float32,name='v1_ph')
v2_ph = tf.placeholder(tf.float32,name='v2_ph')
with tf.Session() as sess:
model_filename = './model/combined_model.pb'
# 读取保存的模型文件,并将文件解析成对应的GraphDef Protobuf Buffer
with gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 查看计算图上有哪些节点,用于确定网络的输入输出节点名
node_names_list = [tensor.name for tensor in graph_def.node]
for node_name in node_names_list:
print(node_name)
# 将graph_def中保存的图加载到当前的图中,return_elements=['add:0']给出了返回的张量名称
# 在保存的时候给出的是计算节点的名称,所以为'add',在加载的时候给出的是张量的名称,所以是'add:0'
result = tf.import_graph_def(graph_def, input_map={'v1:0':v1_ph, 'v2:0':v2_ph}, return_elements=['add:0'])
print(sess.run(result, feed_dict={v1_ph:177, v2_ph:12})) # 输出189.0
Reference
- 郑泽宇等.TensorFLow实战Google深度学习框架(第2版),电子工业出版社,2018.