TensorFlow之共享变量

一、Variable的用法

这里我们把TensorFlow简称为tf。

当我们利用tf声明一个变量时,变量中有一个参数是“name”,这是变量的唯一标识,当我们不指定“name”参数值时,系统自动给当前变量设定“name”值为Variable:0。因此当我们用语句:

var1 = tf.Variable(1.0,name='firstvar')

定义出了一个var1,当我们把这个语句写两遍时,则会在内存中生成两个var1,这是因为他们的name会不一样(后面我们将会看到他们各自的name变成了什么),而对于Session()来说,最后一次定义的var1才是生效的。
下面我们来看代码检验的结果:

import tensorflow as tf

var1 = tf.Variable(1.0,name='firstvar')
print("var1:",var1.name)
var1 = tf.Variable(2.0,name='firstvar')
print("var1:",var1.name)
var2 = tf.Variable(3.0)
print("var2:",var2.name)
var2 = tf.Variable(4.0)
print("var2:",var2.name)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("var1=",var1.eval())
    print("var2=",var2.eval())

下面是此代码运行结果:

var1: firstvar:0
var1: firstvar_1:0
var2: Variable:0
var2: Variable_1:0
var1= 2.0
var2= 4.0

从上面的结果不难看出,两个var1的name确实不一样。而session()只认最后一次定义的var值。

二、使用get_variable获取变量

TensorFlow中可以通过调用get_variable()方法来获取变量,使用此方法来生成变量是以属性name作为唯一标识的,当我们时候用get_variable方法定义变量时,name必须是已经已经存在的,借用上面的例子,我们来测试一下:

import tensorflow as tf

var1 = tf.Variable(1.0,name='firstvar')
print("var1:",var1.name)
var1 = tf.Variable(2.0,name='firstvar')
print("var1:",var1.name)
var2 = tf.Variable(3.0)
print("var2:",var2.name)
var2 = tf.Variable(4.0)
print("var2:",var2.name)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("var1=",var1.eval())
    print("var2=",var2.eval())
    
get_var1 = tf.get_variable("firstvar13",[1],initializer=tf.constant_initializer(0.3))
print("get_var1:",get_var1.name)

get_var1 = tf.get_variable("firstvar14",[1],initializer=tf.constant_initializer(0.4))
print("get_var1:",get_var1.name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("get_var1=",get_var1.eval())

主要看后面几行,因为每运行一次,系统就会防止重复而添加数字表示差别,故而在get_Variable中的参数应当是本次运行的name,运行结果如下:

var1: firstvar_13:0
var1: firstvar_14:0
var2: Variable_12:0
var2: Variable_13:0
var1= 2.0
var2= 4.0
get_var1: firstvar13:0
get_var1: firstvar14:0
get_var1= [1.5064439]

特别注意,若果我们的get_var1和get_var2用相同的名字进行定义时,必须要定义所用于,否则会发生冲突。

三、特定作用域下获取变量

我们通过get_variable定义变量时,如若想要使用相同的名字进行定义,则需要给这样的语句进行范围限定
代码如下:

import tensorflow as tf

var1 = tf.Variable(1.0,name='firstvar')
print("var1:",var1.name)
var1 = tf.Variable(2.0,name='firstvar')
print("var1:",var1.name)
var2 = tf.Variable(3.0)
print("var2:",var2.name)
var2 = tf.Variable(4.0)
print("var2:",var2.name)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("var1=",var1.eval())
    print("var2=",var2.eval())
    
with tf.variable_scope("test1",):    
    get_var1 = tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.3))

with tf.variable_scope("test2",):
    get_var1 = tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.4))

print("get_var1:",get_var1.name)
print("get_var1:",get_var1.name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("get_var1=",get_var1.eval())

运行结果如下:

var1: firstvar:0
var1: firstvar_1:0
var2: Variable:0
var2: Variable_1:0
var1= 2.0
var2= 4.0
get_var1: test2/firstvar:0
get_var1: test2/firstvar:0
get_var1= [0.4]

四、共享变量功能的实现

之所以使用get_Variable方法定义变量,就是为了通过该方法实现变量的共享
在上面定义域的方法variable_scope中,有个布尔属性为reuse,当该属性设置为Ture时,表示get_variable不会创建新的变量,而是使用已经定义过的变量,并在计算任务中找name值相同的已经创建过的变量。
这里我们的scope域还可以嵌套使用,下面的代码将进行展示:

import tensorflow as tf

with tf.variable_scope("test1",):
    #之前运行多次,故而内存中产生了多个firstvar*
    #之所以这里name设为firstvar2,是因为test1/firstvar0-6已经在内存中存在,而我的reuse默认为fale,故而无法创建同名新变量,不更名将会报错
    var1 = tf.get_variable("firstvar7",shape=[2],dtype=tf.float32)
    #这里我们尝试一下嵌套的scope用法
    with tf.variable_scope("test2",):
        var2 = tf.get_variable("firstvar7",shape=[2],dtype=tf.float32)

print("var1:",var1.name)
print("var2:",var2.name)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

with tf.variable_scope("test1",reuse=True):
    #与上面对比,因为此时的reuse设置为True,所以我这里的名字是firstvar2已经存在的变量,不会创建新变量
    var3=tf.get_variable("firstvar7",shape=[2],dtype = tf.float32)
    with tf.variable_scope("test2",reuse=True):
        var4 = tf.get_variable("firstvar7",shape=[2],dtype = tf.float32)
        

print("var3:",var3.name)
print("Var4:",var4.name)

运行结果如下:

var1: test1/firstvar7:0
var2: test1/test2/firstvar7:0
var3: test1/firstvar7:0
Var4: test1/test2/firstvar7:0

这里可以看出,var1和var3的变量是一样的(也是同一个域),这就可以表明var1和var3共用了一个变量。
当我们实际操作时,将var1和var2放到一个网络模型中,而var3和var4放到另一个网络模型中,而这两个网络模型的结果都会作用到一个模型的学习参数上

五、内存问题简述

我使用的编译器是Anaconda工具包里面的Spyder工具,在此编译器中,利用get_variable方法创建变量是,会检查一个计算任务中,是否已经创建过该变量,如果已经创建过,并且本次的reuse并没有设置为True(共享),则会报错。
对此,我们可以直接使用tf.reset_default_graph(),语句,将一个计算任务里的变量清空即可。

六、关于scope的初始化变量

当我们在variable_scope方法中使用属性initializer=tf.constant_initializer(x)时,则该域内的所有变量都会被赋值为x,即使是内里嵌套的域内的变量也会被初始化为x,但是如果我们利用get_variable方法时,在该方法中用同样的属性,设置了一个其他的初始值y,则当前变量会被设为y。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值