TensorFlow collection的概念

  • 内容摘自《TensorFlow实战Google深度学习框架》 第二版

集合(collection)

有效整理TenserFlow程序中的资源也时计算图的一个重要功能。在一个计算图中,可以通过集合(collection) 来管理不同个类别的资源。 tf.add_to_collection 函数可以将资源加入一个或多个集合中,然后通过 tf.get_collection 获取一个集合里面的所有资源。 资源可以是张量、变量或运行tensorflow程序所需的队列资源。

  • 以下是最常用的几个自动维护的集合:
集合名称集合内容使用场景
tf.GraphKeys.VARIABLES所有变量持久化模型
tf.GraphKeys.TRAINABLE_VARIABLES可学习的变量模型训练,生成模型可视化内容
tf.GraphKeys.SUMMARIES日志生成的相关张量TensorFlow计算可视化
tf.GraphKeys.QUEUE_RUNNERS处理输入的QueueRunner输入处理
tf.GraphKeys.MOVING_AVERAGE_VARIABLES所有计算了滑动平均值的变量计算了变量的滑动平均值

通过tf.global_variables() 函数可以拿到当前计算图上的所有变量

张量与变量的关系

变量的声明函数tf.Variable 是一个运算,其输出结果是一个张量。变量是一种特殊的张量。

变量的属性:

  • name: TensorFlow的计算都可以通过计算图的模型来建立,图上每一个节点代表了一个计算,计算的结果就保存在张量之中。张量的命名以 “node:src_output” 的形式给出,例如"add:0" 就说明该张量是计算节点"add"输出的第一个结果。
  • type: 与大部分程序语言类似,变量的类型是不可改变的。
  • shape: 通过设置参数validate_shape = False 可在程序运行中改变维度 (一般不会用); 通过张量的get_shape 函数可获取维度信息 免去CNN中间过程变量人工计算维度的麻烦。
w1 = tf.Variable(tf.random_normal([2,3], stddev=1), name = 'w1')
w2 = tf.Variable(tf.random_normal([2,2], stddev=1), name = 'w2')
tf.assign(w1, w2)  # 报错
tf.assign(w1, w2, validate_shape = False) # 可被成功执行

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值