模型的保存与加载是通过
Saver类来实现的
1.创建一个Saver对象
saver = tf.train.Saver()
在创建这个Saver对象的时候,有一个常用的参数max_to_keep, 用来设置保存模型的个数,默认为5;如果只想保存最后一代的模型,只需要将max_to_keep设置为1即可。即:
saver = tf.train.Saver(max_to_keep = 1)
2.调用save函数,保存模型
saver.save(sess,
save_path,
global_step = None,
latest_filename = None,
meta_graph_suffix = 'meta',
write_meta_graph = True,
write_state = True)
其中主要的参数有sess, save_path, global_step接下来分别介绍:
sess: 保存模型要求必须有一个加载了计算图的会话,而且所有变量必须已经被初始化
save_path:模型保存的路径及保存名称,即一个完整的路径包括地址和文件名:root/files
global_step:用于区分不同训练阶段的结果,如果提供,会被添加到save_path后面
3.调用restore函数,加载模型
saver.restore(sess, save_path)
sess: 加载模型必须有一个加载了计算图的会话,而且所有变量必须已经被初始化
save_path:待加载的模型的路径
4.自动获取最后一次保存的模型
ckpt = tf.train.latest_checkpoint('ckpt/')
saver.restore(sess, ckpt)