tf. get_variable()
why 存在?: 因为如果使用Variable的话每次都会新建变量, 但是我们希望一些变量可以重用, 所以就用到了get_variable().
功能: 它会去搜索变量名, 有就直接用, 没有就新建.
既然用到变量名了, 就涉及到了名字域的概念, 通过不同的域来区别变量名
作用域
tf.Variable(): 如果检测到命名冲突, 系统会自己处理
import tensorflow as tf
with tf.name_scope('scope1'):
var1 = tf.get_variable(name='var1', shape=[1])
var2 = tf.Variable(name='var2', initial_value=[1])
var3 = tf.Variable(name='var2', initial_value=[1])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(var1.name, sess.run(var1))
print(var2.name, sess.run(var2))
print(var3.name, sess.run(var3))
'''
输出:
(u'var1:0', array([-0.10135889], dtype=float32))
(u'scope1/var2:0', array([1], dtype=int32))
(u'scope1/var2_1:0', array([1], dtype=int32))
'''
tf.get_variable(): 如果检测到命名冲突, 系统会报错
import tensorflow as tf
with tf.name_scope('scope1'):
var1 = tf.get_variable(name='var1', shape=[1])
var2 = tf.get_variable(name='var1', shape=[1])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(var1.name, sess.run(var1))
print(var2.name, sess.run(var2))
'''
输出:
ValueError: Variable var1 already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:
File "NameScope.py", line 4, in <module>
var1 = tf.get_variable(name='var1', shape=[1])
'''
设置共享变量-方法一
import tensorflow as tf
with tf.variable_scope('scope1') as scope:
var1 = tf.get_variable(name='var1', shape=[1])
scope.reuse_variables()
var2 = tf.get_variable(name='var1', shape=[1])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(var1.name, sess.run(var1))
print(var2.name, sess.run(var2))
'''
输出:
(u'scope1/var1:0', array([ 1.12260759], dtype=float32))
(u'scope1/var1:0', array([ 1.12260759], dtype=float32))
'''
设置共享变量-方法二
import tensorflow as tf
with tf.variable_scope('foo') as foo_scope:
var1 = tf.get_variable('v', [1])
with tf.variable_scope('foo', reuse=True):
var2 = tf.get_variable('v')
assert var2 == var1
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(var1.name, sess.run(var1))
print(var2.name, sess.run(var2))
'''
输出:
(u'foo/v:0', array([-1.11612296], dtype=float32))
(u'foo/v:0', array([-1.11612296], dtype=float32))
'''