[Tensorflow] 变量和作用域


tf. get_variable()

why 存在?: 因为如果使用Variable的话每次都会新建变量, 但是我们希望一些变量可以重用, 所以就用到了get_variable().
功能: 它会去搜索变量名, 有就直接用, 没有就新建.
既然用到变量名了, 就涉及到了名字域的概念, 通过不同的域来区别变量名


作用域

tf.Variable(): 如果检测到命名冲突, 系统会自己处理

import tensorflow as tf

with tf.name_scope('scope1'):
    var1 = tf.get_variable(name='var1', shape=[1])
    var2 = tf.Variable(name='var2', initial_value=[1])
    var3 = tf.Variable(name='var2', initial_value=[1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var2.name, sess.run(var2))
    print(var3.name, sess.run(var3))

'''
输出:
(u'var1:0', array([-0.10135889], dtype=float32))
(u'scope1/var2:0', array([1], dtype=int32))
(u'scope1/var2_1:0', array([1], dtype=int32))
'''

tf.get_variable(): 如果检测到命名冲突, 系统会报错

import tensorflow as tf

with tf.name_scope('scope1'):
    var1 = tf.get_variable(name='var1', shape=[1])
    var2 = tf.get_variable(name='var1', shape=[1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var2.name, sess.run(var2))

'''
输出: 
ValueError: Variable var1 already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:
  File "NameScope.py", line 4, in <module>
    var1 = tf.get_variable(name='var1', shape=[1])
'''

设置共享变量-方法一

import tensorflow as tf

with tf.variable_scope('scope1') as scope:
    var1 = tf.get_variable(name='var1', shape=[1])
    scope.reuse_variables()
    var2 = tf.get_variable(name='var1', shape=[1])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var2.name, sess.run(var2))

'''
输出:
(u'scope1/var1:0', array([ 1.12260759], dtype=float32))
(u'scope1/var1:0', array([ 1.12260759], dtype=float32))
'''

设置共享变量-方法二

import tensorflow as tf

with tf.variable_scope('foo') as foo_scope:
        var1 = tf.get_variable('v', [1])
with tf.variable_scope('foo', reuse=True):
    var2 = tf.get_variable('v')
assert var2 == var1

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var2.name, sess.run(var2))

'''
输出:
(u'foo/v:0', array([-1.11612296], dtype=float32))
(u'foo/v:0', array([-1.11612296], dtype=float32))
'''
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值