Tensorflow核心工作模式非常简单:
- 定义数据流图(计算图)
- 运行数据流图
如果数据流图不存在,那么肯定是无法运行的。因此我们在工作时只需要关注上述工作流的一部分,有助于更加周密的组织自己的代码,明确工作方向。
用tensorflow构建一个基本的数据流图
import tensorflow as tf # Build our graph nodes, starting from the inputs a = tf.constant(5, name="input_a") b = tf.constant(3, name="input_b") c = tf.mul(a,b, name="mul_c") d = tf.add(a,b, name="add_d") e = tf.add(c,d, name="add_e") # Open up a TensorFlow Session sess = tf.Session() # Execute our output node, using our Session sess.run(e) # Open a TensorFlow SummaryWriter to write our graph to disk writer = tf.train.SummaryWriter('./my_graph', sess.graph) # Close our SummaryWriter and Session objects writer.close() sess.close() # 通过下面代码在终端打开TensorBoard 可以查看数据流图: # $ tensorboard --logdir='./my_graph'
1 . Tensorflow在运行时会产生一个默认的计算图(graph),如上面生成的数据流图,我们在一般情况下都是用的是默认图,程序中的每一个计算(tf.constant(), tf.mul(), tf.add()等)都是计算图上的一个node(节点)。我们可以手工指定(tf.Graph(),后面会讲到)。
2 . Session对象在运行时负责对数据流图进行监督,并且是运行数据流图的主要接口。上述代码将Session对象赋给了变量sess, 以便后期能够对其进行访问。
-
Tensorflow 的 Graph 对象
下面来研究tnesorflow 的 Graph 对象,学习如何创建更多的数据流图,以及如何让多个数据流图协同工作。创建 graph 对象的构造方法非常简单,不需要传入任何参数。
import tensorflow as tf #创建一个新的数据流图 g = tf.graph() #利用Graph.as_default()方法访问其上下文管理器,为其添加op with g.as_default(): #一些op添加进Graph对象g中 a = tf.mul(1, 2) ....... #不在 with语句块中,下面的op将放置在默认图中 In_default_graph = tf.sub(3, 4)
这里就解释到了上文出现的手工指定数据流图。为什么上面的例子不需要指定graph就能添加op,原因是:为了方便起见,当tensorflow库被加载时,它会自动创建一个Graph对象,并将其作为默认的数据流图。因此,在Graph.as_default()上下文管理器之外定义的任何op, tensor对象都会自动放置在默认的数据流图中。
#如果希望获得默认的数据流图的句柄,可使用: default_graph = tf.get_default_graph()
在大多数Tensorflow程序中,只使用默认图就足够了。然而,如果需要定义多个相互之间不存在依赖关系的模型,则创建多个Graph对象十分有用。
""" 当需要在单个文件中定义多个数据流图时,最佳方式是不使用默认图或者为默认图分配句柄 """ #创建新的数据流图,忽略默认图(True) import tensorflow as tf g1 = tf.Graph() g2 = tf.Graph() with g1.as_default(): #定义g1中op, tensor等对象 ...... with g2.as_default(): #定义g2中op, tensor等对象 ...... ------------------------------------------------------------------------------ #获取默认图句柄(True) import tensorflow as tf g1 = tf.get_default_graph() g2 = tf.Graph() with g1.as_default(): #定义g1中op, tensor等对象 ...... with g2.as_default(): #定义g2中op, tensor等对象 ...... ------------------------------------------------------------------------------ #将默认图和用户创建的数据流图混用(False) import tensorflow as tf g1 = tf.Graph() #定义默认图op, tensor对象 ...... with g1.as_default(): #定义g1中op, tensor等对象 ......
此外, tensorflow加载之前定义好的模型(pb),可利用Graph.as_graph-def()和tf.import_graph_def()函数将其赋给Graph对象。这样,用户便可以在一个文件中计算和使用若干独立的模型输出。
-
Tensorflow 的 Session 对象
Session(会话)负责数据流图的执行工作。
tf.Session()接收三个可选参数,在一般的tensorflow程序中, 创建Session对象时无需改变任何默认参数,这里不加赘述。创建完session对象,便可以利用其主要的方法run()来计算输出。
#创建 Session 对象 ###下面两种调用方式是等价的,表示将使用当前的默认图 sess = tf.Session() sess = tf.Session(graph=tf.get_default_graph)
run()接收的参数 :
""" fetches 1.为了取回操作的输出内容, 可以在使用 Session 对象的 run() 调用 执行图时, 传入一些 tensor, 这些 tensor 会帮助你取回结果。 2.除了利用fetches获取tensor对象输出外,有时也会赋予其指向某个op句柄。 """ ------------------------------------------------------------------------------------------- import tensorflow as tf a = tf.constant(3.0) b = tf.constant(2.0) c = tf.add(a, b) mul = tf.mul(a, c) with tf.Session(): result = sess.run([mul, c]) print(result) # 输出: # [array([ 21.], dtype=float32), array([ 7.], dtype=float32)] ------------------------------------------------------------------------------------------- import tensorflow as tf #初始化所有定义的变量 sess.run(tf.initialize_all_variables())
""" Feed_dict TensorFlow还提供feed机制, 该机制可以临时替代图中的任意操作中的tensor可以对图中任何操作提交补丁, 直接插入一个tensor.feed使用一个tensor值临时替换一个操作的输出结果.你可以提供feed数据作为run()调用的参数. feed只在调用它的方法内有效, 方法结束, feed就会消失. 最常见的用例是将某些特殊的操作指定为 "feed"操作, 标记的方法是使用tf.placeholder()为这些操作创建占位符. """ import tensorflow as tf a = tf.placeholder(tf.float32) b = tf.placeholder(tf.float32) output = tf.mul(a, b) #用with语句块,Session对象使用完毕后自动释放资源 with tf.Session() as sess: #运行数据流图 result = sess.run([output], feed_dict={a:[7.], b:[2.]}) print(result) # 输出: # [array([ 14.], dtype=float32)]