- 内容摘自《TensorFlow实战Google深度学习框架》 第二版
集合(collection)
有效整理TenserFlow程序中的资源也时计算图的一个重要功能。在一个计算图中,可以通过集合(collection) 来管理不同个类别的资源。 tf.add_to_collection
函数可以将资源加入一个或多个集合中,然后通过 tf.get_collection
获取一个集合里面的所有资源。 资源可以是张量、变量或运行tensorflow程序所需的队列资源。
- 以下是最常用的几个自动维护的集合:
集合名称 | 集合内容 | 使用场景 |
---|---|---|
tf.GraphKeys.VARIABLES | 所有变量 | 持久化模型 |
tf.GraphKeys.TRAINABLE_VARIABLES | 可学习的变量 | 模型训练,生成模型可视化内容 |
tf.GraphKeys.SUMMARIES | 日志生成的相关张量 | TensorFlow计算可视化 |
tf.GraphKeys.QUEUE_RUNNERS | 处理输入的QueueRunner | 输入处理 |
tf.GraphKeys.MOVING_AVERAGE_VARIABLES | 所有计算了滑动平均值的变量 | 计算了变量的滑动平均值 |
通过tf.global_variables()
函数可以拿到当前计算图上的所有变量
张量与变量的关系
变量的声明函数tf.Variable
是一个运算,其输出结果是一个张量。变量是一种特殊的张量。
变量的属性:
- name: TensorFlow的计算都可以通过计算图的模型来建立,图上每一个节点代表了一个计算,计算的结果就保存在张量之中。张量的命名以 “node:src_output” 的形式给出,例如"add:0" 就说明该张量是计算节点"add"输出的第一个结果。
- type: 与大部分程序语言类似,变量的类型是不可改变的。
- shape: 通过设置参数
validate_shape = False
可在程序运行中改变维度 (一般不会用); 通过张量的get_shape
函数可获取维度信息 免去CNN中间过程变量人工计算维度的麻烦。
w1 = tf.Variable(tf.random_normal([2,3], stddev=1), name = 'w1')
w2 = tf.Variable(tf.random_normal([2,2], stddev=1), name = 'w2')
tf.assign(w1, w2) # 报错
tf.assign(w1, w2, validate_shape = False) # 可被成功执行