When a newly built TensorFlow model is saved, four files are produced:
checkpoint # indexes the latest model file; the model itself consists of the three files below
name.data-00000-of-00001 # stores the values of the model's parameters
name.index # maps each variable name to its location in the .data file
name.meta # stores the model's graph structure
During training we keep saving the model, and a newly saved model does not immediately overwrite the earlier ones, so the checkpoint file is needed as a guide: when restoring, it tells TensorFlow which saved model is the most recent.
In other words, one saved model consists of three files: the file with the .meta suffix stores the computation graph structure, while the other two store the parameters and variables.
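The checkpoint file itself is just a small text file whose model_checkpoint_path entry points at the latest save. As a minimal sketch (the sample contents below are hypothetical), it can be read with plain Python:

```python
# Sample contents of a 'checkpoint' index file (hypothetical values):
sample = '''model_checkpoint_path: "my_test_model-1000"
all_model_checkpoint_paths: "my_test_model-500"
all_model_checkpoint_paths: "my_test_model-1000"
'''

def latest_model_prefix(text):
    # The model_checkpoint_path line names the most recent checkpoint prefix.
    for line in text.splitlines():
        if line.startswith("model_checkpoint_path:"):
            return line.split('"')[1]
    return None

print(latest_model_prefix(sample))  # my_test_model-1000
```

This is the information tf.train.latest_checkpoint uses when it resolves a checkpoint directory to the newest saved model.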
First, let's build a model:
import tensorflow as tf
# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}
# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="result")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Create a saver object which will save all the variables
saver = tf.train.Saver()
# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2.0
# Now, save the graph
saver.save(sess, 'checkpoint/my_test_model', global_step=1000)
What we have implemented is the computation (w1+w2)*b1. The placeholders declare the input interface: w1 and w2 are the entry points for input, and they correspond directly to the keys of feed_dict. Once the graph is described, we can create a session, which is the execution entry point that actually runs the computation.
Finally, save the session: saver.save(sess, 'checkpoint/my_test_model', global_step=1000)
Here sess is the session to save, 'checkpoint/my_test_model' is the save path, and global_step appends -1000 to the file names, to distinguish one save from another.
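To double-check the expected output, here is the same computation in plain Python, with the values taken from feed_dict and the bias initializer above:

```python
w1, w2 = 4, 8              # values supplied through feed_dict
b1 = 2.0                   # initial value of the bias variable
result = (w1 + w2) * b1    # mirrors tf.multiply(tf.add(w1, w2), b1)
print(result)              # 24.0
```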
After saving, the model files look like this:
my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
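The naming pattern above (save path, a dash, the global_step value, then the per-file suffix) can be sketched in plain Python:

```python
def checkpoint_filenames(prefix, global_step):
    # Files written by saver.save(sess, prefix, global_step=global_step)
    stem = "%s-%d" % (prefix, global_step)
    return [stem + suffix
            for suffix in (".index", ".meta", ".data-00000-of-00001")]

for name in checkpoint_filenames("my_test_model", 1000):
    print(name)
# my_test_model-1000.index
# my_test_model-1000.meta
# my_test_model-1000.data-00000-of-00001
```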
Restoring the model
import tensorflow as tf
with tf.Session() as sess:
    # First let's load the meta graph and restore the weights
    saver = tf.train.import_meta_graph('checkpoint/my_test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('checkpoint/'))
    # Now, let's access the saved placeholders and
    # create a feed_dict to feed new data
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict = {w1: 13.0, w2: 13.0}
    # Now, access the op that you want to run.
    op_to_restore = graph.get_tensor_by_name("result:0")
    print(sess.run(op_to_restore, feed_dict))
    # This will print 52.0, which is (13 + 13) * 2.0, calculated
    # using new values of w1 and w2 and the saved value of b1.
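Note the names passed to get_tensor_by_name: a tensor name is the producing op's name plus an output index, so the first output of op w1 is "w1:0". A minimal sketch of that convention in plain Python:

```python
def split_tensor_name(tensor_name):
    # "op_name:output_index" -> (op_name, output_index)
    op_name, index = tensor_name.rsplit(":", 1)
    return op_name, int(index)

print(split_tensor_name("w1:0"))      # ('w1', 0)
print(split_tensor_name("result:0"))  # ('result', 0)
```

This is why the restore code appends :0 to the names we assigned with the name= argument when building the graph.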