TensorFlow(tf)中可用tf.variable_scope函数生产上下文管理器,在该上下文管理器的命名空间内定义的变量名称都会带上这个命名空间名作为前缀,例如 with tf.variable_scope("s1"): 下定义的变量变量名前缀都是"s1/".
从外,tf.train.Saver类可以用于保存指定session下的变量计算图和变量值.可参考以下示例程序:
import tensorflow as tf
tf.reset_default_graph()
with tf.variable_scope("s1"):
v1 = tf.get_variable(name = "v1", initializer=tf.constant(1.0))
v2 = tf.get_variable(name = "v2", initializer=tf.constant(2.0))
with tf.variable_scope("s2"):
v3 = tf.get_variable(name = "v1", initializer=tf.constant(3.0))
v4 = tf.get_variable(name = "v2", initializer=tf.constant(4.0))
result = v1 + v4
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
saver = tf.train.Saver()
saver.save(sess, "./model/model1.ckpt")
print(result.eval())
#输出:
#5.0
然后就可以在另外的程序中加载ckpt文件中存储的变量值(若两程序中相应变量的变量名一样,则无需另外操作):
# -*- coding: utf-8 -*-
import tensorflow as tf
tf.reset_default_graph()
with tf.variable_scope("s1"):
v1 = tf.get_variable(name = "v1", initializer=tf.constant(1.0))
with tf.variable_scope("s2"):
v4 = tf.get_variable(name = "v2", initializer=tf.constant(4.0))
result = v1 + v4
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "./model/model1.ckpt")
print(result.eval())
#输出:
#INFO:tensorflow:Restoring parameters from ./model/model1.ckpt
#5.0
若是两个程序中相应变量的变量名不一样,则需要使用一个字典来指定变量名对应关系:
# -*- coding: utf-8 -*-
import tensorflow as tf
tf.reset_default_graph()
v1 = tf.get_variable(name = "other-v1", initializer=tf.constant(.0))
v2 = tf.get_variable(name = "other-v2", initializer=tf.constant(.0))
result = v1 + v2
with tf.Session() as sess:
saver = tf.train.Saver({"s1/v1": v1, "s2/v2": v2})
saver.restore(sess, "./model/model1.ckpt")
print(result.eval())
#输出:
#INFO:tensorflow:Restoring parameters from ./model/model1.ckpt
#5.0