tensorflow中,模型的存储和恢复使用tf.train.Saver类,模型存储使用该类的 save 方法。模型恢复使用restore 方法。
模型存储
模型存储使用tf.train.Saver.save()方法。以saver.save(sess, 'model/model.ckpt')
为例,在model路径下会有四个文件(如下图)
- checkpoint 记录保存信息,通过它可以定位最新保存的模型;
- *.meta 保存当前图结构;
- *.index 保存当前参数名;
- *.data 保存当前参数名。
import tensorflow as tf
a = tf.get_variable('a', shape=[3], initializer=tf.constant_initializer(1))
b = tf.get_variable('b', shape=[5], initializer=tf.constant_initializer(2))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver.save(sess, 'model/model.ckpt')
模型恢复
模型恢复使用tf.train.Saver.restore() 方法。
```
import tensorflow as tf
a = tf.get_variable('a', shape=[3])
b = tf.get_variable('b', shape=[5])
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'model/model.ckpt')
print(sess.run(a)) # [ 1. 1. 1.]
print(sess.run(b)) # [ 2. 2. 2. 2. 2.]
```
保存和恢复部分变量
使用 save 方法存储模型时,若不指定参数,则 Saver 会处理图中所有的变量。每个变量都保存在创建变量时所传递的名称下。我们还可以对指定变量进行存储和恢复。示例如下:
- save
import tensorflow as tf
a = tf.get_variable('a', shape=[3], initializer=tf.constant_initializer(1))
b = tf.get_variable('b', shape=[5], initializer=tf.constant_initializer(2))
saver = tf.train.Saver({"a": a})
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
saver.save(sess, 'model/model.ckpt')
- restore
import tensorflow as tf
a = tf.get_variable('a', shape=[3])
# b = tf.get_variable('b', shape=[5])
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'model/model.ckpt')
print(sess.run(a))
# print(sess.run(b))
检查某个检查点的变量
使用 inspect_checkpoint 库快速检查某个检查点的变量。
from tensorflow.python.tools import inspect_checkpoint as chkp
# 打印所有 tensors
chkp.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name='', all_tensors=True)
# tensor_name: a
# [ 1. 1. 1.]
# tensor_name: b
# [ 2. 2. 2. 2. 2.]
# 打印 tensor a
chkp.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name='a', all_tensors=False)
# tensor_name: a
# [ 1. 1. 1.]
# 打印 tensor b
chkp.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name='b', all_tensors=False)
# tensor_name: b
# [ 2. 2. 2. 2. 2.]