其一 变量赋值:
import tensorflow as tf
# Part 1: create variables, assign a value to one, and feed a placeholder
# (TensorFlow 1.x graph API).
# Start from an empty graph so this snippet can be re-run in the same process
# without tf.get_variable() raising "Variable count1 already exists".
tf.reset_default_graph()
# Create a 1x2 variable with tf.get_variable(); no initializer is given, so the
# variable scope's default initializer is used.
count_variable = tf.get_variable("count1", [1, 2])
zero_node = tf.constant([[2., 3.]])  # value written into count1 (not zeros, despite the name)
assign_node = tf.assign(count_variable, zero_node)  # op that stores zero_node into count1
const_init_node = tf.constant_initializer(0.)
count_variable1 = tf.get_variable("count2", [], initializer=const_init_node)  # scalar, starts at 0.0
# Op that runs the initializers of all global variables created above
# (count1's default initializer and count2's constant initializer).
init = tf.global_variables_initializer()
input_placeholder = tf.placeholder(tf.int32)  # placeholder: value supplied at run time via feed_dict
with tf.Session() as sess:
    # Initialize first: running `init` after the assignment would reset count1
    # and throw away the value written by assign_node.
    sess.run(init)
    print(sess.run(assign_node))
    print(sess.run(count_variable))
    print(sess.run(count_variable1))
    print(sess.run(input_placeholder, feed_dict={input_placeholder: 2}))  # feed data into the placeholder
其二 共享变量
import tensorflow as tf
# Part 2: share a variable by re-entering a variable scope with reuse=True.
# Reset the default graph so re-running this snippet does not fail on the
# first get_variable('v') with "Variable foo/v already exists".
tf.reset_default_graph()
with tf.variable_scope('foo'):
    # Creates foo/v, initialized to [1.0].
    v = tf.get_variable('v', [1], initializer=tf.constant_initializer(1.0))
with tf.variable_scope('foo', reuse=True):
    # With reuse=True, get_variable looks up the existing foo/v instead of creating a new one.
    v1 = tf.get_variable('v', [1])
print(v)  # <tf.Variable 'foo/v:0' shape=(1,) dtype=float32_ref>
print(v1)  # <tf.Variable 'foo/v:0' shape=(1,) dtype=float32_ref>
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(v1))  # [1.]
其三 将数据保存到模型 ckpt 文件并从中恢复,见下面两段代码
1、保存数据(vv 为 foo 作用域下的共享变量,count1 为 1×2 数据,count2 为 2×3 数据)
import tensorflow as tf
# Part 3, step 1: create three variables and save them to ./tmp/model1.ckpt.
# NOTE(review): the original indentation was lost; this assumes only vv lives
# in scope 'foo' (checkpoint names foo/vv, count1, count2). Keep the restore
# snippet's scopes in sync with this one.
tf.reset_default_graph()
with tf.variable_scope('foo'):
    # Checkpoint name foo/vv, value [1.0].
    v = tf.get_variable('vv', [1], initializer=tf.constant_initializer(1.0))
zero_node = tf.constant([[2., 3.]])
# Passing a Tensor as the initializer lets get_variable infer the shape (1x2),
# so count1 starts at [[2., 3.]] without a separate assign op.
count_variable = tf.get_variable("count1", initializer=zero_node)
const_init_node = tf.constant_initializer(1)
# 2x3 variable filled with 1.
count_variable1 = tf.get_variable("count2", [2, 3], initializer=const_init_node)
init = tf.global_variables_initializer()  # initializes vv, count1 and count2
input_placeholder = tf.placeholder(tf.int32)  # placeholder: value supplied at run time via feed_dict
saver = tf.train.Saver()  # no var_list given: saves the graph's global variables
with tf.Session() as sess:
    sess.run(init)  # the init op has no output value, so nothing to print
    print("count1:%s" % sess.run(count_variable))
    print("count2 %s" % sess.run(count_variable1))
    print(sess.run(input_placeholder, feed_dict={input_placeholder: 2}))  # feed data into the placeholder
    print(sess.run(v))
    # "./" is the current working directory; make sure ./tmp exists before saving.
    tf.gfile.MakeDirs("./tmp")
    save_path = saver.save(sess, "./tmp/model1.ckpt")
    print("Model saved in path: %s" % save_path)
2、恢复并显示 vv、count1、count2
import tensorflow as tf
# Part 3, step 2: rebuild the same variables and restore their values from
# ./tmp/model1.ckpt.
tf.reset_default_graph()
# The names and shapes created here must match those saved in step 1:
# foo/vv [1], count1 [1, 2], count2 [2, 3].
# NOTE(review): the original indentation was lost; this assumes only vv lives
# in scope 'foo', mirroring the save snippet.
with tf.variable_scope('foo'):
    vv = tf.get_variable("vv", shape=[1])
count1 = tf.get_variable("count1", shape=[1, 2])
count2 = tf.get_variable("count2", shape=[2, 3])
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore() loads the saved values into the variables; no initializer is run here.
    saver.restore(sess, "./tmp/model1.ckpt")
    print("vv:%s" % vv.eval())
    print("count1:%s" % count1.eval())
    print("count2:%s" % count2.eval())