class tf.train.Saver
保存和恢复变量
最简单的保存和恢复模型的方法是使用tf.train.Saver 对象。构造器给graph 的所有变量,或是定义在列表里的变量,添加save 和 restore ops。saver 对象提供了方法来运行这些ops,定义检查点文件的读写路径。
检查点是专门格式的二进制文件,将变量name 映射到 tensor value。检查checkpoin 内容最好的方法是使用Saver 加载它。
Savers 可以使用提供的计数器自动计数checkpoint 文件名。这可以是你在训练一个模型时,在不同的步骤维持多个checkpoint。例如你可以使用 training step number 计数checkpoint 文件名。为了避免填满硬盘,savers 自动管理checkpoint 文件。例如,你可以最多维持N个最近的文件,或者没训练N小时保存一个checkpoint.
通过传递一个值给可选参数 global_step ,你可以编号checkpoint 名字。
-
saver.save(sess, 'my-model', global_step=0) ==>filename: 'my-model-0'
-
saver.save(sess, 'my-model', global_step=1000) ==>filename: 'my-model-1000'
另外,Saver() 构造器可选的参数可以让你控制硬盘上 checkpoint 文件的数量。
- max_to_keep: 表明保存的最大checkpoint 文件数。当一个新文件创建的时候,旧文件就会被删掉。如果值为None或0,表示保存所有的checkpoint 文件。默认值为5(也就是说,保存最近的5个checkpoint 文件)。
- keep_checkpoint_every_n_hour: 除了保存最近的max_to_keep checkpoint 文件,你还可能想每训练N小时保存一个checkpoint 文件。这将是非常有用的,如果你想分析一个模型在很长的一段训练时间内是怎么改变的。例如,设置 keep_checkpoint_every_n_hour=2 确保没训练2个小时保存一个checkpoint 文件。默认值10000小时无法看到特征。
注意,你仍然必须调用save() 方法去保存模型。传递这些参数给构造器并不会自动为你保存这些变量。
一个定期保存的训练程序如下这样:
-
#Create a saver
-
saver=tf.train.Saver(...variables...)
-
#Launch the graph and train, saving the model every 1,000 steps.
-
sess=tf.Session()
-
for step in xrange(1000000):
-
sess.run(...training_op...)
-
if step % 1000 ==0:
-
#Append the step number to the checkpoint name:
-
saver.save(sess,'my-model',global_step=step)
除了checkpoint 文件之外,savers 还在硬盘上保存了一个协议缓存,存储最近的checkpoint 列表。这用于管理 被编号的checkpoint 文件,并且通过latest_checkpoint() 可以很容易找到最近的checkpoint 的路径。协议缓存存储在紧挨checkpoint 文件的名为 'checkpoint' 的文件中。
如果你创建了几个savers,你可以调用save() 指定协议缓存的文件名。
tf.train.Saver.__init__(var_list=None, reshape=False, shared=False, max_to_keep=5, keep_checkpoint_every_n_hour=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None)
创建一个Saver
构造器添加操作去保存和恢复变量。
var_list 指定了将要保存和恢复的变量。它可以传dict 或者list
- 变量名字的dict: key 是将用来在checkpoint 文件中存储和恢复的变量的名称。
- 变量的list: 变量的 op name
例如:
-
v1=tf.Variable(..., name='v1')
-
v2=tf.Variable(..., name='v2')
-
# Pass the variables as a dict:
-
saver=tf.train..Saver({'v1':v1, 'v2':v2})
-
# Or pass them as a list
-
saver=tf.train..Saver([v1,v2])
-
# Passing a list is equivalent to passing a dict with the variable op names as keys:
-
saver=tf.train..Saver({v.op.name: v for v in [v1,v2]})
可选参数 reshape ,如果为True,允许从保存文件中恢复一个不同shape 的变量,但元素的数量和type一致。如果你reshap 了一个变量而又想从一个旧的文件中恢复,这是非常有用的。
可选参数 shared,如果为True,通知每个设备上共享的checkpoint.
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True)
保存变量
这个方法运行通过构造器添加的操作。它需要启动图的session。被保存的变量必须经过了初始化。
方法返回新建的checkpoint 文件的路径。路径可以直接传给restore() 进行调用。
参数:
- sess: 用于保存变量的Session
- save_path: checkpoint 文件的路径。如果saver 是共享的,这是共享checkpoint 文件名的前缀。
- global_step: 如果提供了global step number,将会追加到 save_path 后面去创建checkpoint 的文件名。可选参数可以是一个Tensor,一个name Tensor或integer Tensor.
返回值:
一个字符串:保存变量的路径。如果saver 是被共享的,字符串以'-?????-of-nnnnn' 结尾。'nnnnn' 是共享的数目。
保存变量
用tf.train.Saver() 创建一个Saver 来管理模型中的所有变量。
-
#!/usr/bin/env python
-
# coding=utf-8
-
import os
-
import tensorflow as tf
-
# Create some variables.
-
v1=tf.Variable([[1,1],[2,2],[3,3]],name="v1")
-
v2=tf.Variable([[4,4],[5,5],[6,7]],name="v2")
-
# Add an op to initialize the variables.
-
init_op=tf.initialize_all_variables()
-
# Add ops to save and restore all the variables.
-
saver=tf.train.Saver()
-
# Later, launch the model, initialize the variables, do some work, save the variables to disk.
-
with tf.Session() as sess:
-
sess.run(init_op)
-
# Do some work with the model.
-
save_path=saver.save(sess,"/home/yhk/tmp/test/model.ckpt")
-
print "Model saved in file: ", save_path
如果你不给tf.train.Saver() 传入任何参数,那么server 将处理graph 中的所有变量。其中每一个变量都以变量创建时传入的名称被保存。
tf.train.Saver.restore(sess, save_path)
恢复之前保存的变量
这个方法运行构造器为恢复变量所添加的操作。它需要启动图的Session。恢复的变量不需要经过初始化,恢复作为初始化的一种方法。
save_path 参数是之前调用save() 的返回值,或调用 latest_checkpoint() 的返回值。
参数:
- sess: 用于恢复参数的Session
- save_path: 参数之前保存的路径
reference:https://blog.csdn.net/qiqiaiairen/article/details/53184216