TensorFlow中的模型持久化

本文参考《TensorFlow实战Google深度学习框架》一书,总结了一些在TensorFlow在保存训练好的模型过程中使用到的一些API

TF提供了tf.train.Saver类来保存和还原一个神经网络模型

1.模型保存

模型保存的代码如下所示:先声明一个tf.train.Saver对象saver,然后使用saver.save进行保存,该函数的第二个参数是保存的路径。注意保存的文件名后缀为.ckpt。虽然只指定了一个文件路径,但是最终会产生多个文件。
import tensorflow as tf

PATH = 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt'

'''part1:保存模型'''
v1 = tf.Variable(tf.constant([1]), name = 'v1')
v2 = tf.Variable(tf.constant([2]), name = 'v2')
result = v1 + v2
init_op = tf.initialize_all_variables()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(result))
    saver.save(sess, 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt')
产生的文件截图

运行上述代码,就可以将整个网络结构和相关的数据都保存下来。

2.模型恢复

TF中可以使用saver.restore来恢复之前已经保存的模型,以下代码给出了加载这个已经保存的模型的方法:
v1 = tf.Variable(tf.constant([1]), name = 'v1')
v2 = tf.Variable(tf.constant([2]), name = 'v2')
result = v1 + v2

saver = tf.train.Saver()
with tf.Session() as sess2:
    saver.restore(sess2, 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt')
    print(sess2.run(result))
需要注意的是,要使用这种方法恢复模型,则v1, v2和result都必须重新声明,否则会报错,而变量的初始化操作则可以不必运行,被换成了加载已经保存了的模型。 此处的输出为[3]。在这个过程中只要name属性保持'v1'和'v2',用于表示变量的标识符是可以改变的。如v1_,v2_。
如果不希望重复定义上面的变量v1, v2, result,则可以使用另外一种方法来加载已经保存的模型。
import tensorflow as tf
saver = tf.train.import_meta_graph('C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess,'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt')
    print('result',sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
值得注意的是在这个过程中两次使用到的路径名中调用的文件有所不同,model.ckpt.meta保存了计算图的结构。而restore中的路径与前面的例子完全相同。在上面的例子中result的name属性为'add:0',表示其是add这个操作的第一个输出值。我们使用tf.get_default_graph().get_tensor_by_name('add:0')函数来获取当前的计算图和result变量。最后的输出结果是[3]。

在上面的例子中,默认保存和加载了TF计算图上定义的全部变量,但是有的时候可能只需要保存或者加载一部分变量。为了实现这个功能,在声明tf.trian.Saver的对象的时候,可以提供一个列表来指定需要保存或者加载的变量。比如
saver = tf.train.Saver([v1])
在这个过程中就只有变量v1会被保存和加载。如果在这个过程中,想要通过加载已经保存的模型,并且输出v2或者result的值,都会报错,而v1还是可以正常输出[1]。

3.变量重命名

除了可以选取需要被加载的变量,tf.train.Saver类也支持在保存或者加载时给变量重命名。如果先前已经保存了v1和v2两个变量,其name依次为'v1'和'v2', 我们可以在加载的过程中,声明变量的时候对这两个变量进行重命名。该过程如下所示:
v1 = tf.Variable(tf.constant([1]), name = 'other-v1')
v2 = tf.Variable(tf.constant([2]), name = 'other-v2')
saver =  tf.train.Saver({'v1': v1, 'v2': v2})
with tf.Session() as sess2:
    saver.restore(sess2, 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/model.ckpt')
    print(sess2.run(v1))
我们分别为v1和v2两个变量重命名为'other-v1', 'other-v2',在加载模型的时候,如果声明对象的过程中tf.train.Saver()括号中为空,则会报找不到变量的错误。为了重命名,在括号中加入了一个字典,将原来name为'v1'的变量保存到当前的v1变量中,而当前的v1变量的名称为‘other-v1’。v2则同理。当然在此处变量也可以以任意合法的标识符来定义,如v1_,v2_,只要保持统一,这样都是合法的。
这样可以方便使用变量的滑动平均值。只需要将影子变量映射到变量自身,那么在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。为了方便加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了variables_to_restore函数来生成tf.train.Saver类所需要的变量重命名字典。
import tensorflow as tf
v = tf.Variable(0, name = 'v')
ema = tf.train.ExponentialMovingAverage(0.99)
saver = tf.train.Saver(ema.variables_to_restore())

4.其他

在测试或者离线预测时,只需要知道如何从神经网络的输入层经过前向传播得到输出层即可,而不需要变量初始化,模型保存等辅助结点的信息。在迁移学习的过程中,也会遇到类似的情况。TF中提供了convert_variables_to_constants函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个TF计算图可以统一存放在一个文件中。下面的程序提供了一个例子。
import tensorflow as tf
from tensorflow.python.framework import  graph_util

sess = tf.InteractiveSession()

v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = 'v1')
v2 = tf.Variable(tf.constant(2.0, shape = [1]), name = 'v2')
result = v1 + v2

tf.initialize_all_variables().run()

#'''导出当前计算图的GraphDef部分,只需要这个一个部分就可以完成从输入层到输出层的计算过程'''
graph_def = tf.get_default_graph().as_graph_def()
#将图中的变量及其取值转化为常量,同时将图中不必要的结点去掉(比如变量的初始化操作)
#如果只关心程序中定义的某些计算时,和这些计算无关的节点就没有必要保存了。在下面一行代码中,
#最后一个参数['add']给出了需要保存的节点名称。add节点是上面定义的两个变量相加的操作。
out_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
#将导出的模型存入文件
with tf.gfile.GFile('C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/combined_model.pb', 'wb') as f:
    f.write(out_graph_def.SerializeToString())
生成的文件截图:

通过下面的程序可以直接计算定义的加法运算的结果。当只需要得到计算图中某个节点的取值时,这提供了一个更加方便的方法。
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = 'C:/Users/douyh/Desktop/TensorFlow_Learning/TF_API/combined_model.pb'
    #读取保存的模型文件,并且将文件解析成对应的GraphDef Protocol Buffer
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    #将graph_def中保存的图加载到当前的图中。retrun_elements = ['add:0']给出了返回张量的名称
    #在保存的时候给出的是计算节点的名称'add',在加载的时候给出的是张量的名称,所以是‘add:0'
    result = tf.import_graph_def(graph_def, return_elements = ['add:0'])
    print(sess.run(result))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值