TensorFlow创建变量

一、使用tf.Variable函数创建变量

tf.Variable(initial_value=None,trainable=True,collections=None,validate_shape=True,caching_device=None,name=None,variable_def=None,dtype=None,expected_shape=None,import_scope=None)

函数功能:创建一个新的变量,变量的值是initial_value,创建的变量会被添加到[GraphKeys.GLOBAL_VARIABLES]默认的计算图列表中,如果trainable被设置为True,这个变量还会被添加到GraphKeys.TRAINABLE_VARIABLES计算图的集合中。

参数:

initial_value:默认值是None,张量或者是一个python对象可以转成张量,这个initial_value是初始化变量的值。它必须有一个特殊的shape,除非validate_shape设置为False。

trainable:默认的是True,变量还会被添加到GraphKeys.TRAINABLE_VARIABLES计算图集合中。

collections:变量会被添加到这个集合中,默认的集合是[GraphKeys.GLOBAL_VARIABLES]。

validate_shape:如果是False,允许这个变量被初始化一个不知道shape。默认的是True,这个initial_value的shape必须是知道的。

name:变量的名字。

dypte:变量的类型,小数的默认是float32,整数默认是int32。

    a = tf.Variable(initial_value=[1,2,3],name="a")
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    print(a.eval(session=sess))
    #[1 2 3]
    print(a.dtype)
    #<dtype: 'int32_ref'>
二、使用tf.get_variable函数创建变量

tf.get_variable(name,shape=None,dtype=None,initializer=None,regularizer=None,trainable=True,collections=None,
caching_device=None,partitioner=None,validate_shape=True,use_resource=None,custom_getter=None)

函数功能:根据变量的名称来获取变量或者创建变量。

参数:

name:变量的名称(必选)。

shape:变量的shape。

dtype:变量的数据类型。

initializer:变量的初始化值。

1、根据变量的名称创建变量

    b = tf.get_variable(name="b", initializer=[1., 2., 3.])
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    print(b.eval(session=sess))
    #[ 1.  2.  3.]
    print(b.dtype)
    #<dtype: 'float32_ref'>

使用tf.get_variable创建变量的时候,如果不指定name,会报TypeError: get_variable() missing 1 required positional argument: 'name'

2、根据变量的名称获取变量
    with tf.variable_scope("f"):
        #初始化一个变量名称为c的变量
        c = tf.get_variable(name="c",shape=[3],initializer=tf.constant_initializer([1,2,3]))

    with tf.variable_scope("f",reuse=True):
        d = tf.get_variable(name="c",shape=[3])
        sess = tf.Session()
        init = tf.initialize_all_variables()
        sess.run(init)
        print(d.eval(session=sess))
        #[ 1.  2.  3.]
        print(c.eval(session=sess))
        #[ 1.  2.  3.]
        print(d == c)
        #True
在使用tf.get_variable()根据变量的名称来获取已经生成变量的时候,需要通过tf.variable_scope函数来生成一个上下文管理器,并明确指定在这个上下文管理器中。获取变量值的时候,需要将上下文管理器中的reuse设置为True,才能直接获取已经声明的变量,如果不设置reuse会报错。需要注意的是,如果变量名在上下文管理器中已经存在,在获取的时候,如果不将reuse设置为True则会报错。同理,如果上下文管理器中不存在变量名,在使用reuse=True获取变量值的时候,也会报错。
三、tf.variable_scope的嵌套

    with tf.variable_scope("a"):#默认是False
        #查看上下文管理器中的reuse的值
        print(tf.get_variable_scope().reuse) #False
        with tf.variable_scope("b",reuse=True):
            print(tf.get_variable_scope().reuse) #True
            #如果reuse是默认的则保持和上一层的reuse值一样
            with tf.variable_scope("c"):
                print(tf.get_variable_scope().reuse) #True
        print(tf.get_variable_scope().reuse) #False
四、上下文管理器与变量名

    #没有上下文管理器
    a = tf.get_variable(name="a",shape=[2],initializer=tf.constant_initializer([1,2]))
    print(a.name) #a:0,a就是变量名
    #声明上下文管理器
    with tf.variable_scope("f"):
        b = tf.get_variable(name="b",shape=[2],initializer=tf.constant_initializer([1,2]))
        print(b.name) #f/b:0,f代表的是上下文管理器的名称,b代表的是变量的名称
        #嵌套上下文管理器
        with tf.variable_scope("g"):
            c = tf.get_variable(name="c",shape=[2],initializer=tf.constant_initializer([1,2]))
            print(c.name)#f/g/c:0
通过上下文管理器和变量名来获取变量

    #通过带上下文管理器名称和变量名来获取变量
    with tf.variable_scope("",reuse=True):
        d = tf.get_variable(name="f/b")
        print(d == b)  #True
        e = tf.get_variable(name="f/g/c")
        print(e == c)  #True






  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

修炼之路

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值