tensorflow编程基础之共享变量

共享变量

背景:在某种情况下,一个模型需要使用其他模型创建的变量,两个模型一起训练。比如,对抗网络中的生成器模型与判别器模型。如果使用tf.Variable,将会生成一个新的变量,而我们需要的是原来的那个变量。这时就是通过引入get_variable方法,实现共享变量来解决这个问题。这种方法可以使用多套网络模型来训练一套权重。

使用get_variable获取变量

get_variable一般会配合variable_scope一起使用,以实现共享变量。variable_scope的意思是变量作用域。在某一作用域中的变量可以被设置成共享的方式,被其他网络模型使用。
get_variable函数的定义:
tf.get_variable(name,shape,initializer)
在TensorFlow里,使用get_variable生成的变量是以指定的name属性为唯一标识,并不是定义的变量名称。使用时一般通过name属性定位到具体变量,并将其共享到其他模型中。

在特定作用域下获取变量

我们已知使用get_variable创建两个同样名字的变量是行不通的,以下代码会出错:

var1 = tf.get_variable("firstvar",shape=[2],dtype=tf.float32)
var2 = tf.get_variable("firstvar",shape=[2],dtype=tf.float32)

如果真的想要那么做,可以使用variable_scope将它们隔开,代码如下。

import tensorflow as tf

#定义一个作用域test1
with tf.variable_scope("test1"):
    var1 = tf.get_variable("firstvar",shape=[2],dtype=tf.float32)
#定义一个作用域test2
with tf.variable_scope("test2"):
    var2 = tf.get_variable("firstvar",shape=[2],dtype=tf.float32)

print("var1:",var1.name)
print("var2:",var2.name)

结果如下:
var1: test1/firstvar:0
var2: test2/firstvar:0

共享变量功能的实现

使用作用域中的reuse参数来实现共享变量功能。
费了这么大劲来使用get_variable,目的其实是为了通过它实现共享变量的功能。
variable_scope里面有个reuse=True属性,表示使用已经定义过的变量。这时get_variable将不会再创建新的变量,而是去图(一个计算任务)中get_variable所创建过的变量中找name相同的变量。
实例:

#定义一个作用域test1
with tf.variable_scope("test1"):
    var1 = tf.get_variable("firstvar",shape=[2],dtype=tf.float32)
#定义一个作用域test2
with tf.variable_scope("test2"):
    var2 = tf.get_variable("firstvar",shape=[2],dtype=tf.float32)


with tf.variable_scope("test1",reuse=True):
    var3 = tf.get_variable("firstvar",shape=[2],dtype=tf.float32)
#定义一个作用域test2
with tf.variable_scope("test2",reuse=True):
    var4 = tf.get_variable("firstvar",shape=[2],dtype=tf.float32)

print("var1:",var1.name)
print("var2:",var2.name)
print("var3:",var3.name)
print("var4:",var4.name)

结果:
var1: test1/firstvar:0
var2: test2/firstvar:0
var3: test1/firstvar:0
var4: test2/firstvar:0
说明:var1和var3输出名字是一样的。var2和var4输出名字也是一样的,这就实现了共享变量。在实际应用中,可以把var1和var2放到一个网络模型中去训练,把var3和var4放到另一个网络模型中去训练,而两个模型的训练结果都会作用于一个模型的学习参数上。

作用域与操作符的受限范围

实例1:

import tensorflow as tf

with tf.variable_scope("scope1") as sp:
    var1 = tf.get_variable("v",[1])

with tf.variable_scope("scope2"):
    with tf.variable_scope(sp):
        var2 = tf.get_variable("v2",[1])

print("var1:",var1.name)
print("var2:",var2.name)

结果:
var1: scope1/v:0
var2: scope1/v2:0
分析:我们发现,var2并没有受到tf.variable_scope(“scope2”)的影响,这是因为我们采用variable_scope(sp),而sp是with tf.variable_scope(“scope1”) as sp,产生的。

实例2:
概述:变量会受到tf.variable_scope的影响,而操作符会同时受到tf.variable_scope和tf.name_scope的影响。

with tf.variable_scope("scope1") as sp:
    with tf.name_scope("bar"):
        var1 = tf.get_variable("v",[1])
        x = var1+1.0
print("var1:",var1.name)
print("x.op:",x.op.name)

结果:
var1: scope1/v:0
x.op: scope1/bar/add

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值