tensorflow collection

https://blog.csdn.net/shenxiaolu1984/article/details/52815641

tensorflow用集合colletion组织不同类别的对象。tf.GraphKeys中包含了所有默认集合的名称。

collection提供了一种“零存整取”的思路:在任意位置,任意层次都可以创造对象,存入相应collection中;创造完成后,统一从一个collection中取出一类变量,施加相应操作。

例如,tf.Optimizer只优化tf.GraphKeys.TRAINABLE_VARIABLES中的变量。

本文介绍几个常用集合 
- Variable集合:模型参数 
- Summary集合:监测 
- 自定义集合

Variable
Variable被收集在名为tf.GraphKeys.VARIABLES的colletion中

定义
Tensorflow使用Variable类表达、更新、存储模型参数。

Variable是在可变更的,具有保持性的内存句柄,存储着Tensor。必须使用Tensor进行初始化。

k = tf.Variable(tf.random_normal([]), name='k')
1
创建的Variable被添加到默认的collection中。

初始化
在整个session运行之前,图中的全部Variable必须被初始化。

sess = tf.Session()
init = tf.initialize_all_variables() 
sess.run(init)
1
2
3
在执行完初始化之后,Variable中的值生成完毕,不会再变化。

特别强调:Variable的值在sess.run(init)之后就确定了;Tensor的值要在sess.run(x)之后才确定。

获取
和Tensor, Operation一样,Variable也是全局的。 
可以通过tf.all_variables()查看所有tf.GraphKeys.VARIABLES中的对象:

# example for y = k*x
x = tf.constant(1.0, shape=[])      # 0D tensor
k = tf.Variable(tf.constant(0.5, shape=[]) )
y = tf.mul(x, k)
v = tf.all_variables()
1
2
3
4
5
也可以用通用方法直接访问collection:

v = tf.get_collection(tf.GraphKeys.VARIABLES)
1
各类Variable
另外,tensorflow还维护另外几个collection:

函数    集合名    意义
tf.all_variables()    VARIABLES    存储和读取checkpoints时,使用其中所有变量
tf.trainable_variables()    TRAINABLE_VARIABLES    训练时,更新其中所有变量
tf.moving_average_variables()    MOVING_AVERAGE_VARIABLES    ExponentialMovingAverage对象会生成此类变量
tf.local_variables()    LOCAL_VARIABLES    在all_variables()之外,需要用tf.init_local_variables()初始化
tf.model_variables()    MODEL_VARIABLES    
Summary
Summary被收集在名为tf.GraphKeys.SUMMARIES的colletion中

定义
Summary是对网络中Tensor取值进行监测的一种Operation。这些操作在图中是“外围”操作,不影响数据流本身。

用例
我们模仿常见的训练过程,创建一个最简单的用例。

# 迭代的计数器
global_step = tf.Variable(0, trainable=False)
# 迭代的+1操作
increment_op = tf.assign_add(global_step, tf.constant(1))
# 实例应用中,+1操作往往在`tf.train.Optimizer.apply_gradients`内部完成。

# 创建一个根据计数器衰减的Tensor
lr = tf.train.exponential_decay(0.1, global_step, decay_steps=1, decay_rate=0.9, staircase=False)

# 把Tensor添加到观测中
tf.scalar_summary('learning_rate', lr)

# 并获取所有监测的操作`sum_opts`
sum_ops = tf.merge_all_summaries()

# 初始化sess
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)  # 在这里global_step被赋初值

# 指定监测结果输出目录
summary_writer = tf.train.SummaryWriter('/tmp/log/', sess.graph)

# 启动迭代
for step in range(0, 10):
    s_val = sess.run(sum_ops)    # 获取serialized监测结果:bytes类型的字符串
    summary_writer.add_summary(s_val, global_step=step)   # 写入文件
    sess.run(increment_op)     # 计数器+1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
调用tf.scalar_summary系列函数时,就会向默认的collection中添加一个Operation。

再次回顾“零存整取”原则:创建网络的各个层次都可以添加监测;在添加完所有监测,初始化sess之前,统一用tf.merge_all_summaries获取。

查看
SummaryWriter文件中存储的是序列化的结果,需要借助TensorBoard才能查看。

在命令行中运行tensorboard,传入存储SummaryWriter文件的目录:

tensorboard --logdir /tmp/log
1
完成后会提示:

You can navigate to http://127.0.1.1:6006
1
可以直接使用服务器本地浏览器访问这个地址(本机6006端口),或者使用远程浏览器访问服务器ip地址的6006端口。

自定义
除了默认的集合,我们也可以自己创造collection组织对象。网络损失就是一类适宜对象。

tensorflow中的Loss提供了许多创建损失Tensor的方式。

x1 = tf.constant(1.0)
l1 = tf.nn.l2_loss(x1)

x2 = tf.constant([2.5, -0.3])
l2 = tf.nn.l2_loss(x2)
1
2
3
4
5
创建损失不会自动添加到集合中,需要手工指定一个collection:

tf.add_to_collection("losses", l1)
tf.add_to_collection("losses", l2)
1
2
创建完成后,可以统一获取所有损失,losses是个Tensor类型的list:

losses = tf.get_collection('losses')
1
另一种常见操作把所有损失累加起来得到一个Tensor:

loss_total = tf.add_n(losses)
1
执行操作可以得到损失取值:

sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
losses_val = sess.run(losses)
loss_total_val = sess.run(loss_total)
1
2
3
4
5
实际上,如果使用TF-Slim包的losses系列函数创建损失,会自动添加到名为”losses”的collection中。
--------------------- 
作者:shenxiaolu1984 
来源:CSDN 
原文:https://blog.csdn.net/shenxiaolu1984/article/details/52815641 
版权声明:本文为博主原创文章,转载请附上博文链接!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值