Tensorflow中的共享变量机制小结

转自:https://cloud.tencent.com/developer/article/1092432

今天说一下tensorflow的变量共享机制,首先为什么会有变量共享机制? 这个还是要扯一下生成对抗网络GAN,我们知道GAN由两个网络组成,一个是生成器网络G,一个是判别器网络D。G的任务是由输入的隐变量z生成一张图像G(z)出来,D的任务是区分G(z)和训练数据中的真实的图像(real images)。所以这里D的输入就有2个,但是这两个输入是共享D网络的参数的,简单说,也就是权重和偏置。而TensorFlow的变量共享机制,正好可以解决这个问题。但是我现在不能确定,TF的这个机制是不是因为GAN的提出才有的,还是本身就存在。

所以变量共享的目的就是为了在对网络第二次使用的时候,可以使用同一套模型参数。TF中是由Variable_scope来实现的,下面我通过几个栗子,彻底弄明白到底该怎么使用,以及使用中会出现的错误。栗子来源于文档,然后我写了不同的情况,希望能帮到你。

# - * - coding:utf-8 - * -
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def fc_variable():
    v1 = tf.Variable(
        initial_value=tf.random_normal(
            shape=[2, 3], mean=0., stddev=1.),
        dtype=tf.float32,
        name='variable_1')
    print v1
    print "- v1 - * " * 5
    return v1

"""
<tf.Variable 'variable_1:0' shape=(2, 3) dtype=float32_ref>
- v1 - * - v1 - * - v1 - * - v1 - * - v1 - * 
"""

def variable_value(variables):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # 如果没有这句会报错,所以tf在调用变量之前主要
        # 先初始化
        """
        tensorflow.python.framework.errors_impl.
        FailedPreconditionError: Attempting to use
         uninitialized value variable_1
        """
        print '- * - value: - * - ' * 3
        print sess.run(variables)
        """
        [[ 0.00556329  0.20311342 -0.79569227]
         [ 0.1700473   0.9499892  -0.46801034]]
        """


def fc_variable_scope():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=True):
        v1 = tf.get_variable("v")
        print v1.name

"""
foo/v:0
foo/w:0
foo/v:0
"""
# 解释:
# 这里说明v1和v的相同的,还有这里用的是
# get_variable定义的变量,这个和Variable
# 定义变量的区别是,如果变量存在get_variable
# 会获得他的值,如果不存在则创建变量


def fc_variable_scope_v2():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=False):
        v1 = tf.get_variable("v")
        print v1.name


"""
ValueError: Variable foo/v already exists, disallowed. 
Did you mean to set reuse=True in VarScope? Originally
 defined at:
"""
# 解释:
# 当reuse为False的时候由于v1在'fool'这个scope里面,
# 所以和v的name是一样的,而reuse为False,变量命名就起了冲突。


def fc_variable_scope_v3():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=True):
        v1 = tf.get_variable("u", [1])
        print v1.name


"""
ValueError: Variable foo/u does not exist, 
or was not created with tf.get_variable().
 Did you mean to set reuse=None in VarScope?
"""
# 解释:
# 当reuse为True时时候,而这里定义了新变量u,
# 之前不存在,这样也无法reuse。


def fc_variable_scope_v4():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=False):
        v1 = tf.get_variable("u")
        print v1.name

"""
ValueError: Shape of a new variable (foo/u)
 must be fully defined, but instead was <unknown>.

"""
# 解释:
# 这里reuse为Flase,但是定义新变量的时候,
# 必须define fully变量,也就是要指定变量
# 的shape或者初始值等。


def fc_variable_scope_v5():
    with tf.variable_scope("foo"):
        v = tf.get_variable("v", [1])
        print dir(v)
        print v.name
        w = tf.get_variable("w", [1])
        print w.name

    with tf.variable_scope("foo", reuse=False):
        v1 = tf.get_variable("u", [1])
        print v1.name


"""
foo/v:0
foo/w:0
foo/u:0
"""
# 这样就没错了


def fc_variable_scope_v6():
    with tf.variable_scope("foo"):
        v1 = tf.Variable(tf.random_normal(
            shape=[2, 3], mean=0., stddev=1.),
            dtype=tf.float32, name='v1')
        print v1.name
        v2 = tf.get_variable("v2", [1])
        print v2.name

    with tf.variable_scope("foo", reuse=True):
        v3 = tf.get_variable('v2')
        print v3.name
        v4 = tf.get_variable('v1')
        print v4.name


"""
foo/v1:0
foo/v2:0
foo/v2:0

ValueError: Variable foo/v1 does not exist, or
 was not created with tf.get_variable(). Did 
 you mean to set reuse=None in VarScope?

"""

# 解释:
# 这里虽然reuse为True,但是v1是由Variable定义的,
# 不能被get。


def compare_name_and_variable_scope():
    with tf.name_scope("hello") as ns:
        arr1 = tf.get_variable(
            "arr1", shape=[2, 10], dtype=tf.float32)
        print (arr1.name)

    print " - * -" * 5
    with tf.variable_scope("hello") as vs:
        arr1 = tf.get_variable(
            "arr1", shape=[2, 10], dtype=tf.float32)
        print (arr1.name)

"""
arr1:0
 - * - - * - - * - - * - - * -
hello/arr1:0
"""
#解释:
# 这里除了name_scope和variable_scope不同,
# 其他都相同,但是从他们的name,也能看出来区别了。

if __name__ == "__main__":
    fc_variable_scope_v6()
    # # 需要测试那个函数,直接写在这里。

简单总结一下,今天的内容主要是变量定义的两种方法,Variable个get_variable,还有变量的范围以及reuse是什么鬼。通过几个栗子,应该明白了。


  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值