saver.save和saver.restore

saver()与restore()只是保存了session中的相关变量对应的值,并不涉及模型的结构。

Saver的作用是将我们训练好的模型的参数保存下来,以便下一次继续用于训练或测试;Restore则是将训练好的参数提取出来。Saver类训练完后,是以checkpoints文件形式保存。提取的时候也是从checkpoints文件中恢复变量。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。

一般地,Saver会自动的管理Checkpoints文件。我们可以指定保存最近的N个Checkpoints文件,当然每一步都保存ckpt文件也是可以的,只是没必要,费存储空间。

  • saver()可以选择global_step参数来为ckpt文件名添加数字标记:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
  • max_to_keep参数定义saver()将自动保存的最近n个ckpt文件,默认n=5,即保存最近的5个检查点ckpt文件。若n=0或者None,则保存所有的ckpt文件。
  • keep_checkpoint_every_n_hours与max_to_keep类似,定义每n小时保存一个ckpt文件。
...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

Restore

restore(sess, save_path)
# sess: A Session to use to restore the parameters.
# save_path: Path where parameters were previously saved.
  • sess: 保存参数的会话。
  • save_path: 保存参数的路径。
  • 当从文件中恢复变量时,不需要事先对他们进行初始化,因为“恢复”自身就是一种初始化变量的方法。
  • 可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  if val_acc>max_acc:
      max_acc=val_acc
      saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
  if val_acc>max_acc:
      max_acc=val_acc
      saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

程序后半段代码我们可以改为:

sess=tf.InteractiveSession()  
sess.run(tf.global_variables_initializer())

is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
    max_acc=0
    f=open('ckpt/acc.txt','w')
    for i in range(100):
      batch_xs, batch_ys = mnist.train.next_batch(100)
      sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
      val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
      print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
      f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
      if val_acc>max_acc:
          max_acc=val_acc
          saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
    f.close()

#验证阶段
else:
    model_file=tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess,model_file)
    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

参考:
https://www.cnblogs.com/denny402/p/6940134.html
https://blog.csdn.net/hellocsz/article/details/89097380

  • 12
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值