tensorflow版本为1.4.1
tensorflow提供了Saver类用于模型的保存与导入。该类定义在tensorflow/python/training/saver.py.中。
Saver类的默认初始化函数如下:
__init__(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
由于该初始化函数均有缺省值,因此我们常用的创建一个Saver对象的操作为tf.train.Saver()
下面解释一下常用的参数:
- var_list: Variable/SaveableObject的列表,或者是一个字典(mapping names to SaveableObjects)。默认为None,即保存所有可保存的对象。
- reshape: 当为True时,表示从一个checkpoint中恢复参数时允许参数shape发生变化。当我们reshape了一个变量又希望加载旧模型时,该操作就很有用。
- max_to_keep:为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。该参数用于指定保存最近的N个Checkpoints文件,默认为5.
- keep_checkpoint_every_n_hours: 为了避免填满整个磁盘,Saver可以自动的管理C