tf local_variables_initializer 和global_variables_initializer的区别

作者:batman
链接:https://www.zhihu.com/question/61834943/answer/828562407
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
 

首先,global和local variable只是逻辑上的区分!技术上是一样,都是记录在collection里面,也就是你可以自行定义你的变量是global或local,只要在collection里面声明好。如:

e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])

其次,你可以自行定义自己的collection,可以不是local global trainable 或model中的任何一个。

my_variable10 = tf.get_variable("var10", dtype=tf.int32, initializer=tf.constant(3), collections=["my_collection"], trainable=True)    

第三,local 和 global是tf在逻辑上对变量进行的定义和区分。声明的变量默认是global。之所以tf自定义了四种variable collection,是一方面这几个collection最常用,二是针对这四种collection为了方便系统统一初始化或其他操作,tf内部针对它们做了不少工作,也就是https://blog.csdn.net/shenxiaolu1984/article/details/52815641中提到"零存整取"。我们常见的tf自带操作有:

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

需要进行某种操作,只要找到加入了某个collection里面的变量进行对应的操作就好了,非常方便。

 

下面从stackoverflow找到一个例子(https://stackoverflow.com/a/52124613),很全面,以供参考:),很全面,以供参考:

tf.__version__                          # => '1.14.0'                                                                                                                          

# initializing using a Tensor                                                                                                                                                                                                               
my_variable01 = tf.get_variable("var01", dtype=tf.int32, initializer=tf.constant([23, 42]))                                                                                                                                                 
# initializing using a convenient initializer                                                                                                                                                                                               
my_variable02 = tf.get_variable("var02", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.zeros_initializer)                                                                                                                                 

my_variable03 = tf.get_variable("var03", dtype=tf.int32, initializer=tf.constant([1, 2]), trainable=None)                                                                                                                                   
my_variable04 = tf.get_variable("var04", dtype=tf.int32, initializer=tf.constant([3, 4]), trainable=False)                                                                                                                                  
my_variable05 = tf.get_variable("var05", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.ones_initializer, trainable=True)                                                                                                                  

my_variable06 = tf.get_variable("var06", dtype=tf.int32, initializer=tf.constant([5, 6]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=None)                                                                                       
my_variable07 = tf.get_variable("var07", dtype=tf.int32, initializer=tf.constant([7, 8]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=True)                                                                                       

my_variable08 = tf.get_variable("var08", dtype=tf.int32, initializer=tf.constant(1), collections=[tf.GraphKeys.MODEL_VARIABLES], trainable=None)                                                                                            
my_variable09 = tf.get_variable("var09", dtype=tf.int32, initializer=tf.constant(2), collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES, "my_collectio
n"])                                                                                                                                                                                                                                        
my_variable10 = tf.get_variable("var10", dtype=tf.int32, initializer=tf.constant(3), collections=["my_collection"], trainable=True)                                                                                                         

print([var.name for var in tf.global_variables()] ) # => ['var01:0', 'var02:0', 'var03:0', 'var04:0', 'var05:0', 'var09:0']
print([var.name for var in tf.local_variables()] ) # => ['var06:0', 'var07:0', 'var09:0']
print([var.name for var in tf.trainable_variables()] ) # => ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']
print([var.name for var in tf.model_variables()] ) # => ['var08:0', 'var09:0']
print([var.name for var in tf.get_collection("trainable_variables")]  )# => ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']
print([var.name for var in tf.get_collection("my_collection")] ) # => ['var09:0', 'var10:0']
     

亲测结果如下:

['var01:0', 'var02:0', 'var03:0', 'var04:0', 'var05:0', 'var09:0']
['var06:0', 'var07:0', 'var09:0']
['var01:0', 'var02:0', 'var03:0', 'var05:0', 'var06:0', 'var07:0', 'var08:0', 'var09:0', 'var10:0']
['var08:0', 'var09:0']
['var01:0', 'var02:0', 'var03:0', 'var05:0', 'var06:0', 'var07:0', 'var08:0', 'var09:0', 'var10:0']
['var09:0', 'var10:0']

 

我遇到这个问题最早是来源于tf.train.match_filenames_once()这个函数,要加入sess.run(tf.local_variables_initializer())才不会出错,进入这个函数的定义:

@tf_export(
    "io.match_filenames_once",
    v1=["io.match_filenames_once", "train.match_filenames_once"])
@deprecation.deprecated_endpoints("train.match_filenames_once")
def match_filenames_once(pattern, name=None):
  """Save the list of files matching pattern, so it is only computed once.

  NOTE: The order of the files returned can be non-deterministic.

  Args:
    pattern: A file pattern (glob), or 1D tensor of file patterns.
    name: A name for the operations (optional).

  Returns:
    A variable that is initialized to the list of files matching the pattern(s).
  """
  with ops.name_scope(name, "matching_filenames", [pattern]) as name:
    return vs.variable(
        name=name, initial_value=io_ops.matching_files(pattern),
        trainable=False, validate_shape=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])

看最后一行看到问题的核心了,就是这个函数把变量加入到了local variable里面,所以没初始化local variable是无法运行的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值