【TensorFlow】TensorFlow模型保存(save)于恢复(restore)的方法总结

什么是Tensorflow模型?

当你训练好一个神经网络后,你会想保存好你的模型便于以后使用并且用于生产。因此,什么是Tensorflow模型?Tensorflow模型主要包含网络设计(或者网络图)和训练好的网络参数的值。所以Tensorflow模型有两个主要的文件:

a) Meta图: 
Meta图是一个协议缓冲区(protocol buffer),它保存了完整的Tensorflow图;比如所有的变量、运算、集合等。这个文件的扩展名是.meta

b) Checkpoint 文件 
这是一个二进制文件,它保存了权重、偏置项、梯度以及其他所有的变量的取值,扩展名为.ckpt。但是, 从0.11版本开始,Tensorflow对改文件做了点修改,checkpoint文件不再是单个.ckpt文件,而是如下两个文件:

mymodel.data-00000-of-00001
mymodel.index

其中, .data文件包含了我们的训练变量。除此之外,还有一个叫checkpoint的文件,它保留了最新的checkpoint文件的记录。

总结一下,对于0.10之后的版本,tensorflow模型包含以下文件:

保存Tensorflow模型

当训练完成后,我们想要保存所有的变量和网络图便于以后使用。因此在Tensorflow中, 为了保存网络图和所有参数的值,我们应该创建tf.train.Saver()这个类的一个对象。

saver = tf.train.Saver()

Tensorflow变量只有在会话(session)中才能激活。因此,要在会话中调用你刚创建的对象的保存方法。

saver.save(sess, "my-test-model")

sess是一个session对象,“my-test-model”是模型的名字。

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

如果我们要在1000次迭代后保存模型,我们应该在调用保存方法时传入步数计数:

saver.save(sess, "my_test_model", global_step=1000)

这会在模型名称后加一个“-1000”并且会创建如下文件:

my_test_model-1000.index
my_test_mod
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值