把tf变量加到某一个集合中
方法一:系统默认
w1=tf.get_variable('w1',[3,5])
w1=tf.variable([[3,5]])
用以上任何一个语句创建变量时,系统会默认将其添加到表示“全局变量(tf.global_variables
)”和“可训练变量(tf.trainable_variables
)”的集合中。
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
variables = tf.get_collection(tf.GraphKeys.VARIABLES)
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
以上三个语句,都可以得到变量w1的信息
方法二:添加scope
with tf.variable_scope('fc_network'):
w=tf.get_variable('w',[3,5])
b=tf.get_variable('b',[1])
variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='fc_network')
for i in variables:
print(i)
<tf.Variable 'fc_network/w:0' shape=(3, 5) dtype=float32_ref>
<tf.Variable 'fc_network/b:0' shape=(1,) dtype=float32_ref>
这时的集合只包含了scope='fc_network'的全局变量
方法三:添加key
cnames=['fc', tf.GraphKeys.GLOBAL_VARIABLES]
layer1=slim.fully_connected(x, n_l1, variables_collections=cnames)
b3=tf.get_variable('b3',[3,5],collections=cnames)
v=tf.get_collection('fc')#就可以得到slim.fully_connected默认生成的变量w,b以及b3变量
要注意的是collections必须是列表或者元组的形式,不能是字符串‘fc’,而且cnames里的tf.GraphKeys.GLOBAL_VARIABLES不能省略,否则对应变量不会被放在全局变量那个大集合里
方法四:直接使用tf.add_to_collection
tf.add_to_collection(name, value) 用来把一个value放入名称是‘name’的集合
例如:
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(1))
tf.add_to_collection('output', v1)
collection一个很好的应用:
在强化学习DQN中,有一个用来评估的网络q_eval和一个固定不变,用来计算target(为q_eval提供训练的label)的网络q_target,它们两个的网络结构一样,但参数不同。q_eval是不断被训练的,而q_target没有训练过程,它只是过一定时间,就将q_eval的参数复制过来,用于更新。所以需要把q_eval和q_target的变量集合分别提取出来,再赋值。具体代码如下:
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
推荐博文:https://blog.csdn.net/shenxiaolu1984/article/details/52815641