tf.get_collection 可以看做tf的核心函数
Aliases:Wrapper for Graph.get_collection()
using the default graph. Graph.get_collection()的包装器,使用了默认图。
Graph.get_collection()这个函数只要是个图就可以调用,而tf.get_collection是个核心函数,哪里都能用,但只针对默认图
tf.get_collection(
key,
scope=None
)
See tf.Graph.get_collection
for more details. 这里的get_collection是tf中Graph类中的get_collection函数
地址:https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/Graph#get_collection
Args:
key
: The key for the collection. For example, theGraphKeys
class contains many standard names for collections.
关键字:集合的键。例如,GraphKeys
class包含许多集合的标准名字,有关GraphKeys类参考这个网址:
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/GraphKeysscope
: (Optional.) If supplied, the resulting list is filtered to include only items whosename
attribute matches usingre.match
. Items without aname
attribute are never returned if a scope is supplied and the choice orre.match
means that ascope
without special tokens filters by prefix.
scope:(可选)如果提供,则筛选结果列表为仅包含 name 属性匹配 re.match 使用的项目.如果一个范围是提供的,并且选择或 re. match 意味着没有特殊的令牌过滤器的范围,则不会返回没有名称属性的项. 这个实在没看懂,也没找到用例
Returns:
The list of values in the collection with the given name
, or an empty list if no value has been added to that collection. The list contains the values in the order under which they were collected.
集合中具有给定 name 的值的列表,或者如果没有值已添加到该集合中,则为空列表.该列表包含按其收集顺序排列的值.
Eager Compatibility
Collections are not supported when eager execution is enabled.
-----------------------------------------------------------------------------------------------------------------------------
进入主题 Class GraphKeys
GraphKeys 转载网址https://www.deeplearn.me/2479.html
tf.GraphKeys
包含所有graph collection中的标准集合名,有点像 Python 里的 build-in fuction。
首先要了解graph collection是什么。
graph collection
在官方教程——图和会话中,介绍什么是 tf.Graph是这么说的:
tf.Graph
包含两类相关信息:
- 图结构。图的节点和边缘,指明了各个指令组合在一起的方式,但不规定它们的使用方式。图结构与汇编代码类似:检查图结构可以传达一些有用的信息,但它不包含源代码传达的的所有有用上下文。
- **图集合。**TensorFlow 提供了一种通用机制,以便在
tf.Graph
中存储元数据集合。tf.add_to_collection
函数允许您将对象列表与一个键相关联(其中tf.GraphKeys
定义了部分标准键),tf.get_collection
则允许您查询与键关联的所有对象。TensorFlow 库的许多组成部分会使用它:例如,当您创建tf.Variable
时,系统会默认将其添加到表示“全局变量(tf.global_variables
)”和“可训练变量tf.trainable_variables
)”的集合中。当您后续创建tf.train.Saver
或tf.train.Optimizer
时,这些集合中的变量将用作默认参数。
也就是说,在创建图的过程中,TensorFlow 的 Python 底层会自动用一些collection对 op 进行归类,方便之后的调用。这部分collection的名字被称为tf.GraphKeys
,可以用来获取不同类型的 op。当然,我们也可以自定义collection来收集 op。
常见 GraphKeys中的标准键集合
- GLOBAL_VARIABLES: 该 collection 默认加入所有的
Variable
对象,并且在分布式环境中共享。一般来说,TRAINABLE_VARIABLES
包含在MODEL_VARIABLES
中,MODEL_VARIABLES
包含在GLOBAL_VARIABLES
中。 - LOCAL_VARIABLES: 与
GLOBAL_VARIABLES
不同的是,它只包含本机器上的Variable
,即不能在分布式环境中共享。 - MODEL_VARIABLES: 顾名思义,模型中的变量,在构建模型中,所有用于正向传递的
Variable
都将添加到这里。 - TRAINALBEL_VARIABLES: 所有用于反向传递的
Variable
,即可训练(可以被 optimizer 优化,进行参数更新)的变量。 - SUMMARIES: 跟 Tensorboard 相关,这里的
Variable
都由tf.summary
建立并将用于可视化。 - QUEUE_RUNNERS: 队列信息,在 tf 优化中会遇到,正常情况下 io 速度还是相较于 gpu/cpu 慢很多 ,所以为了提升效率异步的方式去处理数据读写和运算 the
QueueRunner
objects that are used to produce input for a computation. - MOVING_AVERAGE_VARIABLES: the subset of
Variable
objects that will also keep moving averages. - REGULARIZATION_LOSSES: 正则化 loss regularization losses collected during graph construction.
The following standard keys are defined, but their collections are not automatically populated as many of the others are:
下面标准keys使已经定义了,但是他们的集合并没有像其他标准Keys一样自动构成
WEIGHTS
BIASES
ACTIVATIONS
下面通过一些例子来说明一下这个是如何使用的。
#tf.GraphKeys是一个集合,其中包含一些标准集合,GLOBAL_VARIABLES就是某一个标准集合的键,返回集合中变量和操作构成的列表
# -*- coding: utf-8 -*-
# @Time : 2019-04-21 17:59
# @Author : zhusimaji
# @File : adadfafaw.py
# @Software: PyCharm
import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",[3,3,32,64],initializer=tf.random_normal_initializer())
b=tf.get_variable("b",[64],initializer=tf.random_normal_initializer())
#collections=None 等价于 collection=[tf.GraphKeys.GLOBAL_VARIABLES]
print("I am a:", a)
print("I am b:", b)
print("I am gv:", tf.GraphKeys.GLOBAL_VARIABLES)
gv= tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
#tf.get_collection(collection_name)返回某个 collection 的列表
print("I am gv:", gv)
for var in gv:
print("Iam var:",var)
print(var is a)
print(var.get_shape())
print("----------------")
输出结果如下所示:
I am a: <tf.Variable 'a:0' shape=(3, 3, 32, 64) dtype=float32_ref>
I am b: <tf.Variable 'b:0' shape=(64,) dtype=float32_ref>
I am gv: variables
I am gv: [<tf.Variable 'a:0' shape=(3, 3, 32, 64) dtype=float32_ref>, <tf.Variable 'b:0' shape=(64,) dtype=float32_ref>]
Iam var: <tf.Variable 'a:0' shape=(3, 3, 32, 64) dtype=float32_ref>
True
(3, 3, 32, 64)
----------------
Iam var: <tf.Variable 'b:0' shape=(64,) dtype=float32_ref>
False
(64,)
----------------
由上面的定义可以看到这些变量的定义相当于一个集合,只要你在计算图中定义了相关的变量,这些集合默认就会记录你当前所有定义的变量,后续你也可以通过这种方式去获取当前所有变量,甚者你可以在 summary 的时候全部记录下来使用。