作者:batman
链接:https://www.zhihu.com/question/61834943/answer/828562407
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
首先,global和local variable只是逻辑上的区分!技术上是一样,都是记录在collection里面,也就是你可以自行定义你的变量是global或local,只要在collection里面声明好。如:
e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])
其次,你可以自行定义自己的collection,可以不是local global trainable 或model中的任何一个。
my_variable10 = tf.get_variable("var10", dtype=tf.int32, initializer=tf.constant(3), collections=["my_collection"], trainable=True)
第三,local 和 global是tf在逻辑上对变量进行的定义和区分。声明的变量默认是global。之所以tf自定义了四种variable collection,是一方面这几个collection最常用,二是针对这四种collection为了方便系统统一初始化或其他操作,tf内部针对它们做了不少工作,也就是https://blog.csdn.net/shenxiaolu1984/article/details/52815641中提到"零存整取"。我们常见的tf自带操作有:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
需要进行某种操作,只要找到加入了某个collection里面的变量进行对应的操作就好了,非常方便。
下面从stackoverflow找到一个例子(https://stackoverflow.com/a/52124613),很全面,以供参考:),很全面,以供参考:
tf.__version__ # => '1.14.0'
# initializing using a Tensor
my_variable01 = tf.get_variable("var01", dtype=tf.int32, initializer=tf.constant([23, 42]))
# initializing using a convenient initializer
my_variable02 = tf.get_variable("var02", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.zeros_initializer)
my_variable03 = tf.get_variable("var03", dtype=tf.int32, initializer=tf.constant([1, 2]), trainable=None)
my_variable04 = tf.get_variable("var04", dtype=tf.int32, initializer=tf.constant([3, 4]), trainable=False)
my_variable05 = tf.get_variable("var05", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.ones_initializer, trainable=True)
my_variable06 = tf.get_variable("var06", dtype=tf.int32, initializer=tf.constant([5, 6]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=None)
my_variable07 = tf.get_variable("var07", dtype=tf.int32, initializer=tf.constant([7, 8]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=True)
my_variable08 = tf.get_variable("var08", dtype=tf.int32, initializer=tf.constant(1), collections=[tf.GraphKeys.MODEL_VARIABLES], trainable=None)
my_variable09 = tf.get_variable("var09", dtype=tf.int32, initializer=tf.constant(2), collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.LOCAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES, "my_collectio
n"])
my_variable10 = tf.get_variable("var10", dtype=tf.int32, initializer=tf.constant(3), collections=["my_collection"], trainable=True)
print([var.name for var in tf.global_variables()] ) # => ['var01:0', 'var02:0', 'var03:0', 'var04:0', 'var05:0', 'var09:0']
print([var.name for var in tf.local_variables()] ) # => ['var06:0', 'var07:0', 'var09:0']
print([var.name for var in tf.trainable_variables()] ) # => ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']
print([var.name for var in tf.model_variables()] ) # => ['var08:0', 'var09:0']
print([var.name for var in tf.get_collection("trainable_variables")] )# => ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']
print([var.name for var in tf.get_collection("my_collection")] ) # => ['var09:0', 'var10:0']
亲测结果如下:
['var01:0', 'var02:0', 'var03:0', 'var04:0', 'var05:0', 'var09:0']
['var06:0', 'var07:0', 'var09:0']
['var01:0', 'var02:0', 'var03:0', 'var05:0', 'var06:0', 'var07:0', 'var08:0', 'var09:0', 'var10:0']
['var08:0', 'var09:0']
['var01:0', 'var02:0', 'var03:0', 'var05:0', 'var06:0', 'var07:0', 'var08:0', 'var09:0', 'var10:0']
['var09:0', 'var10:0']
我遇到这个问题最早是来源于tf.train.match_filenames_once()这个函数,要加入sess.run(tf.local_variables_initializer())才不会出错,进入这个函数的定义:
@tf_export(
"io.match_filenames_once",
v1=["io.match_filenames_once", "train.match_filenames_once"])
@deprecation.deprecated_endpoints("train.match_filenames_once")
def match_filenames_once(pattern, name=None):
"""Save the list of files matching pattern, so it is only computed once.
NOTE: The order of the files returned can be non-deterministic.
Args:
pattern: A file pattern (glob), or 1D tensor of file patterns.
name: A name for the operations (optional).
Returns:
A variable that is initialized to the list of files matching the pattern(s).
"""
with ops.name_scope(name, "matching_filenames", [pattern]) as name:
return vs.variable(
name=name, initial_value=io_ops.matching_files(pattern),
trainable=False, validate_shape=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
看最后一行看到问题的核心了,就是这个函数把变量加入到了local variable里面,所以没初始化local variable是无法运行的。