tensorflow中通过共享 变量作用域(variable_scope)
来实现共享变量 ,节约变量存储空间 。
TensorFlow用于变量管理的函数主要有两个:
tf. get_variable()
用于创建或获取变量的值tf.variable_scope()
用于生成上下文管理器,创建命名空间,命名空间可以嵌套。
函数**tf.get_variable()**既可以创建变量,也可以获取变量。用函数tf.variable.scope()中的参数reuse来控制,分两种情况进行说明:
- 设置
reuse=False
时,函数get_variable()表示创建变量
with tf.variable_scope("foo",reuse=False):
v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
在tf.variable_scope()函数中,设置reuse=False时,在其命名空间"foo"中执行函数get_variable()时,表示创建变量"v",若在该命名空间中已经有了变量"v",则在创建时会报错。
import tensorflow as tf
with tf.variable_scope("foo"):
v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
v1=tf.get_variable("v",[1])
# ValueError: Variable foo/v already exists, disallowed.
- 设置
reuse=True
时,函数get_variable()
表示获取变量
设置 reuse=True 可以再次调用 该共享变量作用域
import tensorflow as tf
with tf.variable_scope("foo"):
v=tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
with tf.variable_scope("foo",reuse=True):
v1=tf.get_variable("v",[1])
print(v1==v) #结果为:True
在tf.variable_scope()函数中,设置reuse=True时,在其命名空间"foo"中执行函数get_variable()时,表示获取变量"v"。若在该命名空间中还没有该变量,则在获取时会报错,如下面的例子
import tensorflow as tf
with tf.variable_scope("foo",reuse=True):
v1=tf.get_variable("v",[1])
# ValueError: Variable foo/v does not exist, or was not created with tf.get_variable()
TensorFlow通过tf. get_variable()和tf.variable_scope()两个函数,可以创建多个并列的或嵌套的命名空间,用于存储神经网络中的各层的权重、偏置、学习率、滑动平均衰减率、正则化系数等参数值,神经网络不同层的参数可放置在不同的命名空间中。同时,变量重用检错和读取不存在变量检错两种机制保证了数据存放的安全性。
一次性对variable_scope进行reuse的两种简便方法
- 使用
tf.Variable_scope(..., reuse=tf.AUTO_REUSE)
- 通过
from tensorflow.python.ops import variable_scope as vs
来导入操作。
# -*- coding: utf-8 -*-
import tensorflow as tf
def func(...):
with tf.variable_scope(name_or_scope='', reuse=tf.AUTO_REUSE): ### 改动部分 ###
方法2中,main函数要设置reuse=(_!=0)
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.python.ops import variable_scope as vs ### 改动部分 ###
def func(..., reuse=False): ### 改动部分 ###
if reuse: ### 改动部分 ###
vs.get_variable_scope().reuse_variables() ### 改动部分 ###
return output
def main():
with tf.Graph().as_default():
pass
for _ in xrange(5):
output = func(..., reuse=(_!=0)) ### 改动部分 ###
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
pass
_output = sess.run(output, feed_dict=...)
pass
if __name__ == "__main__":
main()