我们都知道tensorflow最后生成的模型文件含:
checkpoint
xxxxx.meta
xxxxx.ckpt.data-xxx
xxxxx.index
学习和使用tensorflow的小伙伴肯定都会进行这个过程,我们来看一下怎么操作
上代码:
import tensorflow as tf
model_name = 'xx/xxx/xxx'
#启动一个会话,意味着开始训练
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
#声明一个saver
saver = tf.train.Saver()
for epoch in range(max_epochs):
training() #训练部分
#保存模型
if (epoch%10) == 0 or (epoch + 1) == max_epochs:
saver.save(sess,model_name)
值得注意的点是,saver的声明要在会话里,不然会报错类似于:使用了未初始化(uninitialized)的变量。
这是最基本的用法,我们不用去考虑saver.save函数的参数问题。我们下一篇文章介绍进阶用法。