1. TF---Saver背景介绍
在训练完一个model后hope保存训练的结果(结果:model参数),以便下次迭代的训练或测试, TF针对这一需求提供了Saver类;
(1)Saver类提供了向checkpoints文件保存和从Checkpoints文件中恢复变量的相关方法,Checkpoints文件是一二进制文件,作用:把变量名映射到对应的tensor值。
(2)只要提供一个计数器,但计数器触发时,Saver类可自动生成checkpoint文件,可保存训练过程中多个训练结果,如:保存每一步训练结果…
(3)为避免填充整个磁盘,Saver可自动管理Checkpoints文件,如,可指定保存最近N个Checkpoints文件
In short, Saver类将训练结果保存为checkpoint二进制文件,同时可保存多个checkpoint文件
2. Saver实例------4参数
isTrain: 用来区分训练阶段和测试阶段, True表训练 False 表测试
train_steps: 表 训练次数 example:100
checkpoint_steps: 表训练多少次保存一下checkpoints example:50
checkpoint_dir: 表 checkpoints文件保存路径 example:当前路径
3. 训练阶段
使用Saver.save()方法保存模型:
sess: 当前会话 当前会话记录着当前所有变量值
checkpoint_dir + ‘ model.ckpt ’: 表checkpoint文件的文件名
global_step: 表当前是第几步
训练完成后,保存目录会至少多出4个文件(checkpoint create.py model.ckpt-50 model.ckpt-50.meta …)后面两个文件可能重复
4. 测试阶段
使用saver.restore( ) 方法恢复变量
- sess: 表当前会话,之前保存的结果被加载到这个sess(会话)
- ckpt.model_checkpoint_path:表model存储位置(不需要有模型名字,系统会自动查看ckptpoint 文件,看看最新的是谁,name是什么)
具体流程( 含命令行 )
为了保存训练好的模型参数,以便以后验证和测试,TF提供了tf.train.Saver()模块
模型保存,
1.. 创建Saver对象 同时设置max_to_keep 参数
需要创建一个Saver对象:
saver = tf.train.Saver()
在创建saver对象时,有一参数max_to_keep参数,设置保存模型的个数;默认max_to_keep=5(保存最近的5个模型);若如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0
saver = tf.train.Saver( max_to_keep = 0 )
若只想保存最后一次模型参数,只需saver = tf.train.Saver( max_to_keep = 1 )
2.. 保存模型
创建完Saver对象后,就可以保存模型了
如:saver.save( sess, ‘ckpt/mnist.ckpt’, global_step = step )
第一个参数sess 当前会话 当前会话记录着当前所有变量值
第二个参数指定保存路径和名字
第三个参数:将训练次数作为后缀添加到模型名字后
如:saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
3.. 模型恢复 restore()函数
有两个参数restore( sess, save_path ) save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型
如: model_file = tf.train.latest_checkpoint( ‘ckpt/ ‘ )
saver.restore( sess, model_file )
reference:
https://blog.csdn.net/u011500062/article/details/51728830
http://www.imooc.com/article/33993 结构清晰
http://blog.csdn.net/weixin_38208741/article/details/78812562 例子和细节