在使用TensorFlow的过程中,保存模型参数变量是很重要的一个环节,既可以保证训练过程信息不丢失,也可以帮助我们在需要快速恢复或使用一个模型的时候,利用之前保存好的参数之间导入,可以节省大量的训练时间。本文通过最简单的例程教大家如何保存和读取.ckpt文件。
一、保存到文件
首先是导入必要的东西:
import tensorflow as tf
import numpy as np
随便写几个变量:
# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
init= tf.initialize_all_variables()
定义一个saver,来存储我们的各种变量:
saver = tf.train.Saver()
保存的文件用.ckpt后缀:
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess, "my_net/