Variables

Saving and Restoring Variables

tf.train.Saver

tf.train.Saver(var_lise=None)

tf.get_variable()/tf.variable_scope

tf.get_variable(name, shape=None, initializer=None)
tf.variable_scope(name_or_scope, reuse=None)

a = tf.get_variable('a', shape=[2,5], initializer=tf.truncated_normal_initializer())
a = tf.Variable(tf.truncated_normal(shape=[2,5]), name='a')

以上两个定义等价。
tf.get_variable()与tf.Variable()最大区别在于参数name,前者是必填参数,后者是可选参数。

tf.variable_scope()可以控制tf.get_variable()的功能,当reuse=True生成上下文管理器时,tf.get_variable()根据name的值直接获取已创建的变量,如果变量不存在(没有name那个值的变量)则会报错。当reuse=None或reuse=False创建上下文管理器,tf.get_variable创建新的变量,如果有同名变量(name一样),则会报错。

with tf.variable_scope('layer1'):
    w = tf.get_variable('w', shape=[2,5], initializer=tf.truncated_normal_initializer())
# w的name为'layer1/w',而下面创建的w1的name也为
# 'layer1/w',所以会报错。
with tf.variable_scope('layer1'):
    w1 = tf.get_variable('w', shape=[2,5])

with tf.variable_scope('layer1', reuse=True):
    w2 = tf.get_variable('w', [2,5])
    print(w2 == w)    # => True
# w2和w指向的是同一个东西,即'layer1/w'.

with tf.variable_scope('layer1', reuse=True):
    w3 = tf.get_variable('b', [2,5])
# 由于命名空间中没有'layer1/b',且reuse=True,tf.get_variable是获取非创建新变量,故报错。

with tf.variable_scope('layer1', reuse=True):
    w4 = tf.get_variable('w', [2])
# 虽然命名空间中有'layer1/w',但是shape不一样,故报错。

with tf.variable_scope('layer1', reuse=True):
    w5 = tf.get_variable('w', [2, 5], initializer=tf.constant_intializer())
# 虽然此时创建的initializer不一样,但不会报错,并且此时的initializer没有意义,依然是之前w的initializer.

tf.variable_scope()嵌套使用时,reuse参数的取值见如下代码,

with tf.variable_scope('layer1', reuse=True):
    print(tf.get_variable_scope().reuse)    # => True
    with tf.variable_scope('layer11'):
        print(tf.get_variable_scope().reuse)
        # 不指定reuse,内层的reuse取值会应用外一层的,此时输出为True。若外层reuse=FALSE,此时输出False。

tf.variable_scope()生成的上下文管理器去创建Tensorflow的命名空间还有一个用法,见下代码。

with tf.variable_scope("", reuse=True):
    w9 = tf.get_variable('layer1/w2', [2,5])
    # 创建一个空的命名空间,w9可以用来获得其他命名空间中的变量'layer1/w2'.

tf.Variable()

tf.Variable(initial_value=None, trainable=True, name=None)

a = tf.Variable(...,trainable=False,...)

参数trainable=False表示不将变量a加入到tf.GrapsKeys.TRAINABLE_VARIABLES集合中,即变量a不会出现在tf.trainable_variables()函数返回的变量列表中。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值