最近学习tensorflow,对其中的模型的导入导出一直比较困惑,因此花了些力气研究了一下,最后归纳整理到本篇博客。
变量保存和恢复
在tensorflow中,变量用来存储和更新参数。变量创建时可以赋予name和初始值,并且在执行模型的其他操作之前必须对变量进行初始化。比较简单的一个方法是添加一个对初始化所有变量的操作,在使用模型前先执行这个操作。比如:
#添加变量初始化操作
init_op = tf.initialize_all_variables()
#执行模型前先执行初始化操作
with tf.Session() as sess:
# Run the init operation.
sess.run(init_op)
sess.run(...)
CheckPoint File
Checkpoint文件是用来保存Graph中定义的变量的二进制文件,包含了从变量名和变量值的映射关系。
保存变量
在tensorflow中保存和回复变量的方法是使用tf.train.Saver对象,利用Saver构造器可以给graph的变量添加save和restore的ops,将变量保存或从磁盘读取。
下面是保存变量的一个例子:
# 创建变量
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 添加一个变量初始化的op
init_op = tf.initialize_all_variables()
# 添加一个ops 保存并恢复全部变量
saver = tf.train.Saver()
# 创建一个Session执行Graph
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# 保存变量
save_path = saver.save(sess, "/tmp/model.ckpt")
print "Model saved in file: ", save_path
恢复变量
同样用Saver可以恢复变量,恢复时不需要进行初始化,但必须提前声明与恢复数据匹配的变量来接收数据
# 创建变量.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print "Model restored."