tf.get_collection:从一个结合中取出全部变量,是一个列表
tensorflow用集合
colletion
组织不同类别的对象。
tf.GraphKeys
中包含了所有默认集合的名称。collection提供了一种“零存整取”的思路:在任意位置,任意层次都可以创造对象,存入相应collection
中;创造完成后,统一从一个collection
中取出一类变量,施加相应操作。例如,tf.Optimizer
只优化tf.GraphKeys.TRAINABLE_VARIABLES
中的变量。
Variable
被收集在名为tf.GraphKeys.VARIABLES
的colletion
中
定义
Tensorflow使用Variable
类表达、更新、存储模型参数。
k = tf.Variable(tf.random_normal([]), name='k')
创建的
Variable
被添加到默认的
collection
中.
获取
和
Tensor
,
Operation
一样,
Variable
也是全局的。
可以通过tf.all_variables()查看所有
tf.GraphKeys.VARIABLES
中的对象:
也可以用通用方法直接访问
collection
:
with tf.variable_scope('scope'): v1 = tf.get_variable('var', [1]) with tf.variable_scope('scope2'): v2 = tf.get_variable('var', [1])v = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
v = tf.all_variables()#这两种方式都可以
print(v)#[<tf.Variable 'scope/var:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'scope/scope2/var:0' shape=(1,) dtype=float32_ref>]
各类Variable
tensorflow维护多个collection
:
函数 集合名 意义 tf.all_variables() VARIABLES 存储和读取checkpoints时,使用其中所有变量 tf.trainable_variables() TRAINABLE_VARIABLES 训练时,更新其中所有变量 tf.moving_average_variables() MOVING_AVERAGE_VARIABLES ExponentialMovingAverage
对象会生成此类变量tf.local_variables() LOCAL_VARIABLES 在all_variables()
之外,需要用tf.init_local_variables()初始化 tf.model_variables() MODEL_VARIABLES
操作
tf.add_to_collection:把变量放入一个集合,把很多变量变成一个列表tf.get_collection:从一个集合中取出全部变量,是一个列表
tf.add_n:把一个列表的东西都依次加起来
import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as plt;
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(0))
tf.add_to_collection('loss', v1)
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
tf.add_to_collection('loss', v2)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print (tf.get_collection('loss') ) #[<tf.Variable 'v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(1,) dtype=float32_ref>]
print (sess.run(tf.add_n(tf.get_collection('loss'))) )#[ 2.]
v = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(v)#[<tf.Variable 'v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'v2:0' shape=(1,) dtype=float32_ref>]