tensorflow collection

tensorflowcollection提供一个全局的存储机制,不会受到变量名生存空间的影响。一处保存,到处可取。

接口介绍

#向collection中存数据
tf.Graph.add_to_collection(name, value)

#Stores value in the collection with the given name.
#Note that collections are not sets, so it is possible to add a value to a collection
#several times.
# 注意,一个‘name’下,可以存很多值; add_to_collection("haha", [a,b]),这种情况下
#tf.get_collection("haha")获得的是 [[a,b]], 并不是[a,b]
tf.add_to_collection(name, value)
#这个和上面函数功能上没有区别,区别是,这个函数是给默认图使用的
 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
#从collection中获取数据
tf.Graph.get_collection(name, scope=None)

Returns a list of values in the collection with the given name.

This is different from get_collection_ref() which always returns the actual
collection list if it exists in that it returns a new list each time it is called.

Args:

name: The key for the collection. For example, the GraphKeys class contains many
standard names for collections.
scope: (Optional.) If supplied, the resulting list is filtered to include only
items whose name attribute matches using re.match. Items without a name attribute
are never returned if a scope is supplied and the choice or re.match means that
a scope without special tokens filters by prefix.
#返回re.match(r"scope", item.name)匹配成功的item, re.match(从字符串的开始匹配一个模式)
Returns:

The list of values in the collection with the given name, or an empty list if no
value has been added to that collection. The list contains the values in the
order under which they were collected.

tf自己也维护一些collection,就像我们定义的所有summary op都会保存在name=tf.GraphKeys.SUMMARIES。这样,tf.get_collection(tf.GraphKeys.SUMMARIES)就会返回所有定义的summary op


tf.add_to_collection:把变量放入一个集合,把很多变量变成一个列表

tf.get_collection:从一个结合中取出全部变量,是一个列表

tf.add_n:把一个列表的东西都依次加起来

import tensorflow as tf;    
import numpy as np;    
import matplotlib.pyplot as plt;    
  
v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(0))  
tf.add_to_collection('loss', v1)  
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))  
tf.add_to_collection('loss', v2)  
  
with tf.Session() as sess:  
    sess.run(tf.initialize_all_variables())  
    print tf.get_collection('loss')  
    print sess.run(tf.add_n(tf.get_collection('loss')))
也可以在声明变量的时候指定放到哪个collection中

w1 = tf.get_variable('w1', [self.n_features, n_l1], initializer=w_initializer, collections=c_names)
                b1 = tf.get_variable('b1', [1, n_l1], initializer=b_initializer, collections=c_names)
                l1 = tf.nn.relu(tf.matmul(self.s_, w1) + b1)



  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值