class tf.train.Saver
保存和恢复变量
最简单的保存和恢复模型的方法是使用tf.train.Saver 对象。构造器给graph 的所有变量,或是定义在列表里的变量,添加save 和 restore ops。saver 对象提供了方法来运行这些ops,定义检查点文件的读写路径。
检查点是专门格式的二进制文件,将变量name 映射到 tensor value。检查checkpoin 内容最好的方法是使用Saver 加载它。
Savers 可以使用提供的计数器自动计数checkpoint 文件名。这可以是你在训练一个模型时,在不同的步骤维持多个checkpoint。例如你可以使用 training step number 计数checkpoint 文件名。为了避免填满硬盘,savers 自动管理checkpoint 文件。例如,你可以最多维持N个最近的文件,或者没训练N小时保存一个checkpoint.
通过传递一个值给可选参数 global_step ,你可以编号checkpoint 名字。
saver.save(sess, 'my-model', global_step=0) ==>filename: 'my-model-0'
saver.save(sess, 'my-model', global_step=1000) ==>filename: 'my-model-1000'
另外,Saver() 构造器可选的参数可以让你控制硬盘上 checkpoint 文件的数量。
- max_to_keep: 表明保存的最大checkpoint 文件数。当一个新文件创建的时候,旧文件就会被删掉。如果值为None或0,表示保存所有的checkpoint 文件。默认值为5(也就是说,保存最近的5个checkpoint 文件)。
- keep_checkpoint_every_n_hour: 除了保存最近的max_to_keep checkpoint 文件,你还可能想每训练N小时保存一个checkpoint 文件。这将是非常有用的,如果你想分析一个模型在很长的一段训练时间内是怎么改变的。例如,设置 keep_checkpoint_every_n_hour=2 确保没训练2个小时保存一个checkpoint 文件。默认值10000小时无法看到特征。
一个定期保存的训练程序如下这样:
#Create a saver
saver=tf.train.Saver(...variables...)
#Launch the graph and train, saving the model every 1,000 steps.
sess=tf.Session()
for step in xrange(1000000):
sess.run(...training_op...)
if step % 1000 ==0:
#Append the step number to the checkpoint name:
saver.save(sess,'my-model',global_step=step)
除了checkpoint 文件之外,savers 还在硬盘上保存了一个协议缓存,存储最近的checkpoint 列表。这用于管理 被编号的checkpoint 文件,并且通过latest_checkpoint() 可以很容易找到最近的checkpoint 的路径。协议缓存存储在紧挨checkpoint 文件的名为 'checkpoint' 的文件中。
如果你创建了几个savers,你可以调用save() 指定协议缓存的文件名。
tf.train.Saver.__init__(var_list=None, reshape=False, shared=False, max_to_keep=5, keep_checkpoint_every_n_hour=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None)
创建一个Saver
构造器添加操作去保存和恢复变量。
var_list 指定了将要保存和恢复的变量。它可以传dict 或者list
- 变量名字的dict: key 是将用来在checkpoint 文件中存储和恢复的变量的名称。
- 变量的list: 变量的 op name
v1=tf.Variable(..., name='v1')
v2=tf.Variable(..., name='v2')
# Pass the variables as a dict:
saver=tf.train..Saver({'v1':v1, 'v2':v2})
# Or pass them as a list
saver=tf.train..Saver([v1,v2])
# Passing a list is equivalent to passing a dict with the variable op names as keys:
saver=tf.train..Saver({v.op.name: v for v in [v1,v2]})
- sess: 用于保存变量的Session
- save_path: checkpoint 文件的路径。如果saver 是共享的,这是共享checkpoint 文件名的前缀。
- global_step: 如果提供了global step number,将会追加到 save_path 后面去创建checkpoint 的文件名。可选参数可以是一个Tensor,一个name Tensor或integer Tensor.
#!/usr/bin/env python
# coding=utf-8
import os
import tensorflow as tf
# Create some variables.
v1=tf.Variable([[1,1],[2,2],[3,3]],name="v1")
v2=tf.Variable([[4,4],[5,5],[6,7]],name="v2")
# Add an op to initialize the variables.
init_op=tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver=tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
save_path=saver.save(sess,"/home/yhk/tmp/test/model.ckpt")
print "Model saved in file: ", save_path
- sess: 用于恢复参数的Session
- save_path: 参数之前保存的路径