Retrieving variables and operations from a TensorFlow graph

TensorFlow provides a series of methods for retrieving and traversing the variables and operations in a computation graph.

Retrieving a single operation/variable

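The following two methods retrieve a single tensor or operation from the graph by name:
1. tf.Graph.get_tensor_by_name(tensor_name): tensor names have the form "<op_name>:<output_index>", e.g. "v1:0". For a variable this returns the variable op's output tensor, not the tf.Variable object itself.
2. tf.Graph.get_operation_by_name(op_name): operation names carry no ":<index>" suffix, e.g. "inc_v1".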
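A minimal sketch of both lookups, assuming TensorFlow 1.14+ or 2.x, where the TF1 API is reached through `tensorflow.compat.v1` and `disable_v2_behavior()` restores graph mode with ref variables. The graph `g` and the names `v1`/`inc_v1` are chosen to mirror the examples in this post.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode with ref variables

g = tf.Graph()
with g.as_default():
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
    inc_v1 = tf.assign(v1, v1 + 1, name="inc_v1")

# Tensor names are "<op_name>:<output_index>".
v1_tensor = g.get_tensor_by_name("v1:0")
# Operation names have no output-index suffix.
inc_op = g.get_operation_by_name("inc_v1")

print(v1_tensor)               # output tensor of the variable op, not the tf.Variable
print(inc_op.type)             # "Assign"
print(inc_op.outputs[0].name)  # "inc_v1:0"
```

Going from an operation to its tensors (`op.outputs`) and back (`tensor.op`) is often more convenient than building name strings by hand.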

Batch retrieval

The main ways to retrieve items in bulk are:
1. graph.as_graph_def().node

import tensorflow.compat.v1 as tf  # TF1-style API; available in TF 1.14+ and TF 2.x
tf.disable_v2_behavior()           # graph mode and ref variables, matching the output below

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
v3 = tf.get_variable("v3", shape=[4], initializer=tf.zeros_initializer())

inc_v1 = tf.assign(v1, v1 + 1, name='inc_v1')
dec_v2 = tf.assign(v2, v2 - 1, name='dec_v2')
dec_v3 = tf.assign(v3, v3 - 2, name='dec_v3')

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Launch the model, initialize the variables and do some work.
with tf.Session() as sess:

  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  dec_v3.op.run()

  # Print every NodeDef in the serialized graph.
  for n in tf.get_default_graph().as_graph_def().node:
    print(n)
Output (excerpt; only the first node is shown):
name: "v1/Initializer/zeros"
op: "Const"
attr {
  key: "_class"
  value {
    list {
      s: "loc:@v1"
    }
  }
}
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_FLOAT
      tensor_shape {
        dim {
          size: 3
        }
      }
      float_val: 0.0
    }
  }
}

This lists the full details (name, op type and attributes) of every node in the graph.
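Because each entry of `graph_def.node` is a `NodeDef` protocol buffer, its fields can be used to filter or summarize the graph. A hedged sketch, assuming TensorFlow 1.14+ or 2.x with `tensorflow.compat.v1`; counting op types with `collections.Counter` is just one illustrative use.

```python
from collections import Counter

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode with ref variables

g = tf.Graph()
with g.as_default():
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
    inc_v1 = tf.assign(v1, v1 + 1, name="inc_v1")

graph_def = g.as_graph_def()  # GraphDef protocol buffer

# Each NodeDef carries the node name, its op type and its attributes.
node_names = [node.name for node in graph_def.node]
op_counts = Counter(node.op for node in graph_def.node)

# Keep only the nodes of one op type, here the Const nodes.
const_nodes = [node.name for node in graph_def.node if node.op == "Const"]

print(node_names)
print(op_counts)
print(const_nodes)
```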
2. graph.get_operations()

Replace the final loop of the script above with:

  for op in tf.get_default_graph().get_operations():
    print('name:' + op.name)
    print('value:' + str(op.values()))
Output (excerpt):
name:v1/Initializer/zeros
value:(<tf.Tensor 'v1/Initializer/zeros:0' shape=(3,) dtype=float32>,)
name:v1
value:(<tf.Tensor 'v1:0' shape=(3,) dtype=float32_ref>,)

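op.values() returns the tensors produced by the op as a tuple (the same tensors as op.outputs); from each tensor you can further read its name, shape, dtype and so on.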
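A small sketch of reading tensor metadata from `op.values()`, assuming TensorFlow 1.14+ or 2.x through `tensorflow.compat.v1`. The dictionaries `shapes` and `dtypes` are illustrative names. Note that an op with no outputs, such as the `NoOp` built by `global_variables_initializer()`, yields an empty tuple.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode with ref variables

g = tf.Graph()
with g.as_default():
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
    init_op = tf.global_variables_initializer()

shapes = {}
dtypes = {}
for op in g.get_operations():
    for t in op.values():  # same tensors as op.outputs, as a tuple
        shapes[t.name] = t.shape.as_list()
        dtypes[t.name] = t.dtype
        print(op.name, op.type, t.name, t.shape, t.dtype)
```

The variable's own tensor reports a `float32_ref` dtype; `dtype.base_dtype` strips the reference to give `float32`.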
3. tf.all_variables()

import tensorflow.compat.v1 as tf  # TF1-style API; available in TF 1.14+ and TF 2.x
tf.disable_v2_behavior()           # graph mode and ref variables, matching the output below

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
v3 = tf.get_variable("v3", shape=[4], initializer=tf.zeros_initializer())

inc_v1 = tf.assign(v1, v1 + 1, name='inc_v1')
dec_v2 = tf.assign(v2, v2 - 1, name='dec_v2')
dec_v3 = tf.assign(v3, v3 - 2, name='dec_v3')

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Launch the model, initialize the variables and do some work.
with tf.Session() as sess:

  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  dec_v3.op.run()

  # tf.all_variables() is a deprecated alias of tf.global_variables().
  for variable in tf.all_variables():
    print(variable)
    print(variable.name)
Output:
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
v1:0
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
v2:0
<tf.Variable 'v3:0' shape=(4,) dtype=float32_ref>
v3:0

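This returns all Variable objects in the default graph's GLOBAL_VARIABLES collection. tf.all_variables() is deprecated; use tf.global_variables() instead.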
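The same listing with the non-deprecated name, as a sketch assuming TensorFlow 1.14+ or 2.x through `tensorflow.compat.v1`. It also shows that `tf.global_variables()` simply reads the `GLOBAL_VARIABLES` collection, which leads into the next method.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode with ref variables

g = tf.Graph()
with g.as_default():
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
    v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())

    # Current name for what tf.all_variables() returned.
    global_names = [v.name for v in tf.global_variables()]
    # The same list, read straight from the GLOBAL_VARIABLES collection.
    collection_names = [v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]

print(global_names)      # ['v1:0', 'v2:0']
print(collection_names)  # ['v1:0', 'v2:0']
```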
4. tf.get_collection(collection_key)

  for variable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(variable)
Output:
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
<tf.Variable 'v3:0' shape=(4,) dtype=float32_ref>

This returns the list of objects stored in the collection with the given key.

TensorFlow predefines a set of standard graph collection keys in the class GraphKeys:

Standard names to use for graph collections.
The standard library uses various well-known names to collect and retrieve values associated with a graph. For example, the tf.Optimizer subclasses default to optimizing the variables collected under tf.GraphKeys.TRAINABLE_VARIABLES if none is specified, but it is also possible to pass an explicit list of variables.
The following standard keys are defined:
  ● GLOBAL_VARIABLES: the default collection of Variable objects, shared across distributed environment (model variables are a subset of these). See tf.global_variables for more details. Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES.
  ● LOCAL_VARIABLES: the subset of Variable objects that are local to each machine. Usually used for temporary variables, like counters. Note: use tf.contrib.framework.local_variable to add to this collection.
  ● MODEL_VARIABLES: the subset of Variable objects that are used in the model for inference (feed forward). Note: use tf.contrib.framework.model_variable to add to this collection.
  ● TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer. See tf.trainable_variables for more details.
  ● SUMMARIES: the summary Tensor objects that have been created in the graph. See tf.summary.merge_all for more details.
  ● QUEUE_RUNNERS: the QueueRunner objects that are used to produce input for a computation. See tf.train.start_queue_runners for more details.
  ● MOVING_AVERAGE_VARIABLES: the subset of Variable objects that will also keep moving averages. See tf.moving_average_variables for more details.
  ● REGULARIZATION_LOSSES: regularization losses collected during graph construction.
The following standard keys are defined, but their collections are not automatically populated as many of the others are:
  ● WEIGHTS
  ● BIASES
  ● ACTIVATIONS 
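How a variable ends up in GLOBAL_VARIABLES, TRAINABLE_VARIABLES or LOCAL_VARIABLES can be checked directly. A sketch assuming TensorFlow 1.14+ or 2.x through `tensorflow.compat.v1`; since `tf.contrib` no longer exists in TF 2.x, it uses the `collections` argument of `tf.get_variable` instead of `tf.contrib.framework.local_variable`. The names `w`, `step` and `counter` are illustrative.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode with ref variables

g = tf.Graph()
with g.as_default():
    # Default: GLOBAL_VARIABLES and TRAINABLE_VARIABLES.
    w = tf.get_variable("w", shape=[2], initializer=tf.zeros_initializer())
    # trainable=False: GLOBAL_VARIABLES only, so optimizers skip it by default.
    step = tf.get_variable("step", shape=[], dtype=tf.int64,
                           initializer=tf.zeros_initializer(), trainable=False)
    # Explicit collections list: LOCAL_VARIABLES only.
    counter = tf.get_variable("counter", shape=[], initializer=tf.zeros_initializer(),
                              trainable=False,
                              collections=[tf.GraphKeys.LOCAL_VARIABLES])

    def names(key):
        return [v.name for v in tf.get_collection(key)]

    global_names = names(tf.GraphKeys.GLOBAL_VARIABLES)
    trainable_names = names(tf.GraphKeys.TRAINABLE_VARIABLES)
    local_names = names(tf.GraphKeys.LOCAL_VARIABLES)

print(global_names)     # ['w:0', 'step:0']
print(trainable_names)  # ['w:0']
print(local_names)      # ['counter:0']
```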

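Besides the predefined collections, TensorFlow also supports custom collections through tf.add_to_collection(name, value) and tf.get_collection(key). Collections provide a graph-wide storage mechanism that is not affected by name scopes. The code is as follows: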

import tensorflow.compat.v1 as tf  # TF1-style API; available in TF 1.14+ and TF 2.x
tf.disable_v2_behavior()           # graph mode and ref variables, matching the output below

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
v3 = tf.get_variable("v3", shape=[4], initializer=tf.zeros_initializer())

inc_v1 = tf.assign(v1, v1 + 1, name='inc_v1')
dec_v2 = tf.assign(v2, v2 - 1, name='dec_v2')
dec_v3 = tf.assign(v3, v3 - 2, name='dec_v3')

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Launch the model, initialize the variables and do some work.
with tf.Session() as sess:

  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  dec_v3.op.run()

  # Add two variables and the output tensor of inc_v1 to a custom collection 'test'.
  tf.add_to_collection('test', v1)
  tf.add_to_collection('test', v2)
  tf.add_to_collection('test', inc_v1)

  for element in tf.get_collection('test'):
    print(element)
Output:
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
Tensor("inc_v1:0", shape=(3,), dtype=float32_ref)
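To illustrate that collections ignore name scopes, the sketch below (assuming TensorFlow 1.14+ or 2.x through `tensorflow.compat.v1`) adds variables to a collection from inside two different variable scopes and reads them back from outside both. The optional `scope` argument of `tf.get_collection` then filters items by matching their names against the given prefix. The scope names `layer1`/`layer2` are illustrative.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode with ref variables

g = tf.Graph()
with g.as_default():
    with tf.variable_scope("layer1"):
        a = tf.get_variable("a", shape=[1], initializer=tf.zeros_initializer())
        tf.add_to_collection("test", a)
    with tf.variable_scope("layer2"):
        b = tf.get_variable("b", shape=[1], initializer=tf.zeros_initializer())
        tf.add_to_collection("test", b)

    # Outside both scopes, the collection still holds both items.
    all_items = [v.name for v in tf.get_collection("test")]
    # The scope argument keeps only items whose name matches the prefix.
    layer1_items = [v.name for v in tf.get_collection("test", scope="layer1")]

print(all_items)     # ['layer1/a:0', 'layer2/b:0']
print(layer1_items)  # ['layer1/a:0']
```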

TensorFlow also provides a way to list every collection in the graph (continuing inside the same with block):

  for key in tf.get_default_graph().get_all_collection_keys():
    print('key:' + key)
    for element in tf.get_collection(key):
      print(element)
Output (excerpt):
key:variables
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
<tf.Variable 'v3:0' shape=(4,) dtype=float32_ref>
key:trainable_variables
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
<tf.Variable 'v3:0' shape=(4,) dtype=float32_ref>
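When using collections as global storage, note that `tf.get_collection` returns a copy of the stored list, while `tf.get_collection_ref` returns the graph's own list, so modifying it changes the collection. A hedged sketch assuming TensorFlow 1.14+ or 2.x through `tensorflow.compat.v1`; the collection name `test` is illustrative.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode with ref variables

g = tf.Graph()
with g.as_default():
    v = tf.get_variable("v", shape=[1], initializer=tf.zeros_initializer())
    tf.add_to_collection("test", v)

    # get_collection returns a new list; clearing it leaves the graph untouched.
    snapshot = tf.get_collection("test")
    snapshot.clear()
    size_after_copy_clear = len(tf.get_collection("test"))

    # get_collection_ref returns the graph's own list; clearing it empties the collection.
    ref = tf.get_collection_ref("test")
    ref.clear()
    size_after_ref_clear = len(tf.get_collection("test"))

    keys = g.get_all_collection_keys()

print(size_after_copy_clear)  # 1
print(size_after_ref_clear)   # 0
print(keys)
```

The emptied collection's key is still reported by `get_all_collection_keys()`, because the graph keeps the (now empty) list.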