集合
tensorflow用集合colletion
组织不同类别的对象。tf.GraphKeys
中包含了所有默认集合的名称。
collection
提供了一种“零存整取”的思路:在任意位置,任意层次都可以创造对象,存入相应collection
中;创造完成后,统一从一个collection
中取出一类变量,施加相应操作。
例如,
tf.Optimizer
只优化tf.GraphKeys.TRAINABLE_VARIABLES
中的变量。
本文介绍几个常用集合
- Variable
集合:模型参数
- Summary
集合:监测
- 自定义集合
Variable
Variable
被收集在名为tf.GraphKeys.VARIABLES
的colletion
中
定义
Tensorflow使用Variable
类表达、更新、存储模型参数。
Variable
是在可变更的,具有保持性的内存句柄,存储着Tensor
。必须使用Tensor
进行初始化。
k = tf.Variable(tf.random_normal([]), name='k')
创建的Variable
被添加到默认的collection
中。
初始化
在整个session
运行之前,图中的全部Variable
必须被初始化。
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
在执行完初始化之后,Variable
中的值生成完毕,不会再变化。
特别强调:Variable
的值在sess.run(init)之后就确