tensorflow:tf.Graph

本文详细介绍了TensorFlow中的Graph类,包括获取默认图、命名空间、变量作用域、集合管理和控制依赖等关键方法。Graph代表数据流图,包含运算和张量。通过tf.Graph.as_default设置默认图,tf.variable_scope实现变量复用,get_all_collection_keys和get_collection用于管理集合。控制依赖用于指定操作的执行顺序,确保某些操作在其他操作之后执行。
摘要由CSDN通过智能技术生成

Class Graph


A TensorFlow computation, represented as a dataflow graph.

一个Graph 包含很多tf.Operation 对象( represent units of computation;)和tf.Tensor 对象 (represent the units of data that flow between operations)

该类下的方法

tf.get_default_graph

- 获取默认的图

定义一个Operation就会添加新的操作到默认图:

c = tf.constant(4.0)
assert c.graph is tf.get_default_graph()

tf.Graph.as_default

创立新的默认图。

  • 注:这个类对于计算图来说不是线程安全的。
g = tf.Graph()
with g.as_default():
  # Define operations and tensors in `g`.
  c = tf.constant(30.0)
  assert c.graph is g

一个Graph 实例可以包括任意的”collections”。比如 tf.Variable 在创建一个图的时候放在这个 collection ( tf.GraphKeys.GLOBAL_VARIABLES) ,其他collection可以通过声明别的名字。collections相当对图中的类似的计算单元做了打包

name_scope

name_scope(name)

作用:Returns a context manager that creates hierarchical names for operations.

#操作单元在流程图上的操作名称
with tf.Graph().as_default() as g:
  c = tf.constant(5.0, name="c")
  assert c.op.name == "c"
  c_1 = tf.constant(6.0, name="c")
  assert c_1.op.name == "c_1"

#nested被声明为图g的默认scope
  # Creates a scope called "nested"
  with g.name_scope("nested") as scope:
    nested_c = tf.constant(10.0, name="c")
    assert nested_c.op.name == "nested/c"

    #图g创建scope "inner",被放在默认的scope“nested”内
    with g.name_scope("inner"):
      nested_inner_c = tf.constant(20.0, name="c")
      assert nested_inner_c.op.name == "nested/inner/c"

    # 名称为inner的scope已经存在,自动起别名 "inner_1",同样放在“nested”内
    with
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进
05-26
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值