前言
我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。
· Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
· 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
· 为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。
1 实例化对象
创建一个Saver对象:如
saver=tf.train.Saver()
max_to_keep 参数:这个是用来设置保存模型的个数,表明保存的最大checkpoint文件数。当一个新文件创建的时候,旧文件就会被删掉。默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:
saver=tf.train.Saver(max_to_keep=0)
keep_checkpoint_every_n_hour: 除了保存最近的max_to_keep_checkpoint文件,你还可能想每训练N小时保存一个checkpoint文件。这将是非常有用的,如果你想分析一个模型在很长的一段训练时间内是怎么改变的。例如,设置keep_checkpoint_every_n_hour=2确保每训练2个小时保存一个checkpoint文件。
2 保存训练过程中或者训练好的, 模型图及权重参数
2.1保存训练模型
saver.save(sess=sess, save_path=model_save_path, global_step=step)
第一个参数sess=sess, 会话名字;
第二个参数save_path=model_save_path, 设定权重参数保存到的路径和文件名;
第三个参数global_step=step, 将训练的次数作为后缀加入到模型名字中,表示当前是第几步
2.2 查看保存
训练完成后,当前目录底下会多出几个文件
- 打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置
- 权重等参数被保存到model.ckpt.data文件中,以字典的形式;
- model.ckpt-index, 是内部需要的某种索引来正确映射前两个文件;
- 图和元数据被保存到model.ckpt.meta文件中,可以使用tf.train.import_meta_graph加载
3. 重载模型的图及权重参数(模型恢复)
重载模型的参数,继续训练或用于测试数据
saver.restore(sess=sess, save_path = model_save_path)
- 第一个参数sess=sess, 会话名字
- 第二个参数save_path=model_save_path, 权重参数的保存路径和文件名,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。