tensorflow 变量用法及恢复模型中的数据演示

其一 变量赋值:

import tensorflow as tf
count_variable = tf.get_variable("count1",[1,2]) #使用 tf.get_variable()创建变量
zero_node = tf.constant([[2.,3.]])
assign_node = tf.assign(count_variable, zero_node) #将值放入变量

const_init_node = tf.constant_initializer(0.)
count_variable1 = tf.get_variable("count2", [], initializer=const_init_node)
init = tf.global_variables_initializer() #与16行呼应,将6和7关联

input_placeholder = tf.placeholder(tf.int32) #占位符

with tf.Session() as sess:
    print(sess.run(assign_node))
    print(sess.run(count_variable))

    sess.run(init)
    print(sess.run(count_variable1))

    print(sess.run(input_placeholder, feed_dict={input_placeholder: 2})) #占位符喂数据

其二 共享变量

import tensorflow as tf

with tf.variable_scope('foo'):
    v = tf.get_variable('v', [1], initializer=tf.constant_initializer(1.0))

with tf.variable_scope('foo', reuse=True):
    v1 = tf.get_variable('v', [1])

print(v)  # <tf.Variable 'foo/v:0' shape=(1,) dtype=float32_ref>
print(v1)  # <tf.Variable 'foo/v:0' shape=(1,) dtype=float32_ref>
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(v1))  # [1.]

 

其三 恢复保存在模型ckpt中的数据,如下两个代码

1、保存数据(vv共享数据,count1 1*2维数据,count2 2*3维数据)

import tensorflow as tf
tf.reset_default_graph()
with tf.variable_scope('foo'):
    v = tf.get_variable('vv', [1], initializer=tf.constant_initializer(1.0))

zero_node = tf.constant([[2.,3.]])
count_variable = tf.get_variable("count1",initializer=zero_node)  #使用 tf.get_variable()创建变量,并将值放入变量

const_init_node = tf.constant_initializer(1)
count_variable1 = tf.get_variable("count2", [2,3], initializer=const_init_node)
init = tf.global_variables_initializer() #与13行呼应,将6和7关联

input_placeholder = tf.placeholder(tf.int32) #占位符
saver = tf.train.Saver()
with tf.Session() as sess:
    print(sess.run(init))
    print("count1:%s"%sess.run(count_variable))
    print("count2 %s"%sess.run(count_variable1))
    print(sess.run(input_placeholder, feed_dict={input_placeholder: 2})) #占位符喂数据
    print(sess.run(v))

    save_path = saver.save(sess, "./tmp/model1.ckpt")  # ./表示当前目录
    print("Model saved in path: %s" % save_path)

2、显示vv,count1 ,count2 

import tensorflow as tf
tf.reset_default_graph()

with tf.variable_scope('foo'):
  vv=tf.get_variable("vv",shape=[1])
count1=tf.get_variable("count1",shape=[1,2])
count2=tf.get_variable("count2",shape=[2,3])
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, "./tmp/model1.ckpt")
  print("vv:%s"%vv.eval())
  print("count1:%s" % count1.eval())
  print("count1:%s" % count2.eval())

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值