前面介绍了通过使用tf.train.Saver函数来保存TensorFlow程序的参数,但是,在使用tf.train.Saver函数保存模型文件的时候,是保存所有的参数信息,而有些时候我们并不需要所有的参数信息。我们只需要知道神经网络的输入层经过前向传播计算得到输出层即可,所以在保存的时候,我们也不需要保存所有的参数,以及变量的初始化、模型保存等辅助节点信息与迁移学习类似。之前使用tf.train.Saver函数保存模型文件的时候会产生多个文件,它将变量的取值和计算图结构分成了不同的文件存储。TensorFlow提供了另一种保存模型文件的方法,将计算图保存在一个文件中。
1、模型文件的保存
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
if __name__ == "__main__":
a = tf.Variable(tf.constant(5.,shape=[1]),name="a")
b = tf.Variable(tf.constant(6.,shape=[1]),name="b")
c = a + b
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
#导出当前计算图的GraphDef部分
graph_def = tf.get_default_graph().as_graph_def()
#保存指定的节点,并将节点值保存为常数
output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
#将计算图写入到模型文件中
model_f = tf.gfile.GFile("model.pb","wb")
model_f.write(output_graph_def.SerializeToString())
convert_variables_to_constants函数,会将计算图中的变量取值以常量的形式保存。在保存模型文件的时候,我们只是导出了GraphDef部分,GraphDef保存了从输入层到输出层的计算过程。在保存的时候,通过convert_variables_to_constants函数来指定保存的
节点名称而不是
张量的名称,
“add:0”是张量的名称而
"add"表示的是节点的名称。
2、模型文件的读取
sess = tf.Session()
#将保存的模型文件解析为GraphDef
model_f = gfile.FastGFile("model.pb",'rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_f.read())
c = tf.import_graph_def(graph_def,return_elements=["add:0"])
print(sess.run(c))
#[array([ 11.], dtype=float32)]
在读取模型文件获取变量的值的时候,我们需要指定的是
张量的名称而不是
节点的名称,这两个地方需要特别注意一下。在读取模型文件的时候,可能会遇到一个错误
tensorflow.python.framework.errors_impl.NotFoundError: NewRandomAccessFile failed to Create/Open: ./model.pd,打开模型文件的时候报错,所以一般都是模型文件的名称有问题,要注意统一。在这个地方,我遇到了一个很奇葩的问题,我读取模型文件的名称和保存的文件名称是一样,但是还是报错,后面我将保存模型文件的名称复制到读取模型文件的名称那里,就解决了这个问题。