声明:
- 参考Tensorflow官方文档
- tensorflow当前版本1.1
- 更新:现在tensorflow官网有了中文教程,很方便学习了
tf.train.Saver()
tf.train.Saver()
是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。
TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。在tf.train.Saver()
类初始化时,用于保存和恢复的save
和restore
operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。
saver = tf.train.Saver()
TensorFlow的保存和恢复分为两种:
- 保存和恢复变量
- 保存和恢复模型
保存变量
TensorFlow会讲变量保存在二进制checkpoint文件中,这类文件会将变量名称映射到张量值。
下面是保存变量的例子:
- 创建变量
- 初始化变量
- 实例化
tf.train.Saver()
- 创建Session并保存
import tensorflow as tf
# Create some variables.
v1 = tf.get_variable("v1_name", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2_name", shape=[5], initializer = tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train