import tensorflow as tf
1. 概述
Tensorflow中的运算过程, 会被表示成一个图, 由一系列的op(操作对象)作为节点组成. Tensor对象作为数据节点, 在op之间起到连接作用. 在程序运行的开始就会自动生成默认的图.
# Print the graph object that TensorFlow registers as the default graph.
print(tf.get_default_graph())
# Create a constant without naming a graph and print the graph it belongs to,
# for comparison with the default graph printed above.
a = tf.constant(1.0)
print(a.graph)
'''
结果:
<tensorflow.python.framework.ops.Graph object at 0x0000024C257BB518>
<tensorflow.python.framework.ops.Graph object at 0x0000024C257BB518>
可见在程序运行的开始, 就已经创建了默认图, 并能通过tf.get_default_graph()方法得到默认图的对象.
创建的Tensor, Variable, Op等也都会加入到该图中.
'''
这种由系统自动创建默认图的方式不便于管理, 通常我们使用如下的方法, 在上下文中管理:
通过 tf.Graph() 创建新的Graph, 并用 as_default() 将这个Graph设置为默认的图. 配合with上下文管理器使用时, 在该上下文中新图会覆盖原来的默认图, 其中创建的节点都会加入到新的tf.Graph()图中.
# Demonstrate that `Graph.as_default()` swaps the default graph only for the
# duration of the `with` block.
print(tf.get_default_graph())  # default graph before entering the context
g = tf.Graph()
print(g)
with g.as_default():
    # Inside the context the default graph is `g`.
    print(tf.get_default_graph())
# After leaving the context the previous default graph is active again.
print(tf.get_default_graph())
'''
结果:
<tensorflow.python.framework.ops.Graph object at 0x0000020DEFA2B5C0>
<tensorflow.python.framework.ops.Graph object at 0x0000020DEFBCD7B8>
<tensorflow.python.framework.ops.Graph object at 0x0000020DEFBCD7B8>
<tensorflow.python.framework.ops.Graph object at 0x0000020DEFA2B5C0>
'''
综上, 一般使用下述方法管理Graph:
# Template: build ops inside a dedicated graph rather than the implicit
# default graph. Replace `pass` with the graph-construction code.
with tf.Graph().as_default():
    pass
2.图的函数
一个图中, 会有一个字典属性collection, 将图中的变量节点通过名字归类, 便于查找和区分, 即通过设置变量的collections属性, 将节点以字符串为key值归纳.
collection有以下默认字段:
variables
所有在该图下创建的Variable ;
trainable_variables
该图所有的可训练变量, 在创建该变量时trainable参数设置为True(默认)
Graph类关于collection的函数:
- tf.Graph.get_all_collection_keys()
'''
作用: 获取图中collection所有key值字符串
输出: string list, 元素为graph._collections字典中的key值字符串
'''
# List every collection key present in a freshly built graph.
with tf.Graph().as_default():
    # `a` is created without an explicit `collections` argument.
    a = tf.Variable([1, 1], name="a")
    # `b` is registered under the custom collection keys given here.
    b = tf.get_variable("b", shape=[2], collections=["test1", "test2"])
    # Must be fetched inside the context so it refers to the new graph.
    graph = tf.get_default_graph()
    print(graph.get_all_collection_keys())
'''
结果:
['test1', 'variables', 'test2', 'trainable_variables']
'''
- tf.Graph.add_to_collection()
'''
作用: 将变量value添加到名称为name的collection中
参数:
name: (must)collection的名称, 只能输入一个collection的名称, 格式为字符串;
value: 变量对象
'''
# Add an existing variable to an extra collection and list all collection keys.
with tf.Graph().as_default():
    a = tf.Variable([1, 1], name="a")
    b = tf.get_variable("b", shape=[2], collections=["test1", "test2"])
    # Must be fetched inside the context so it refers to the new graph.
    graph = tf.get_default_graph()
    # Register `a` under the additional key "test3".
    graph.add_to_collection("test3", a)
    print(graph.get_all_collection_keys())
'''
结果:
['test1', 'trainable_variables', 'test3', 'test2', 'variables']
'''
Graph类其他重要函数:
- tf.Graph.get_operations()
'''
作用: 返回图中所有op, 以列表的形式
输出: op list, 元素为op对象
'''
# Build a small graph computing relu((W @ W2) @ x + b), run it once, then list
# every operation the graph contains.
with tf.Graph().as_default():
    W = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32, name="W")
    # Pass an initializer instance; the documented contract of `initializer`
    # is an initializer object or a Tensor, not the class itself.
    W2 = tf.get_variable(
        "W2", initializer=tf.random_normal_initializer(), shape=[2, 2])
    x = tf.placeholder(tf.float32, shape=[2, 1], name="x")
    b = tf.constant([8, 6], dtype=tf.float32, shape=[2, 1], name="b")
    z = tf.matmul(
        tf.matmul(W, W2, name="matmul_inside"), x, name="matmul_outside") + b
    y = tf.nn.relu(z, name="relu")
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run([y], feed_dict={x: [[1.0], [1.0]]})
    # Still inside the graph context, so this is the graph built above.
    graph = tf.get_default_graph()
    print(graph.get_operations())
'''
结果:
[<tf.Operation 'W/initial_value' type=Const>, <tf.Operation 'W' type=VariableV2>, <tf.Operation 'W/Assign' type=Assign>, <tf.Operation 'W/read' type=Identity>, <tf.Operation 'W2/Initializer/random_normal/shape' type=Const>, <tf.Operation 'W2/Initializer/random_normal/mean' type=Const>, <tf.Operation 'W2/Initializer/random_normal/stddev' type=Const>, <tf.Operation 'W2/Initializer/random_normal/RandomStandardNormal' type=RandomStandardNormal>, <tf.Operation 'W2/Initializer/random_normal/mul' type=Mul>, <tf.Operation 'W2/Initializer/random_normal' type=Add>, <tf.Operation 'W2' type=VariableV2>, <tf.Operation 'W2/Assign' type=Assign>, <tf.Operation 'W2/read' type=Identity>, <tf.Operation 'x' type=Placeholder>, <tf.Operation 'b' type=Const>, <tf.Operation 'matmul_inside' type=MatMul>, <tf.Operation 'matmul_outside' type=MatMul>, <tf.Operation 'add' type=Add>, <tf.Operation 'relu' type=Relu>, <tf.Operation 'init' type=NoOp>]
可以看出, 有些函数方法会创造很多op, 比如tf.Variable()和tf.get_variable()虽然返回的是一个Variable, 但对应该变量的操作却有很多(如初始化, assign赋值, read读取, 以及Variable本身, 都是op)
'''
- tf.Graph.get_operation_by_name()
'''
作用: 通过op的名称获取op
参数:
name: op的名称
输出: op
'''
# Build a small graph computing relu((W @ W2) @ x + b), run it once, then look
# up a single operation by its name.
with tf.Graph().as_default():
    W = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32, name="W")
    # Pass an initializer instance; the documented contract of `initializer`
    # is an initializer object or a Tensor, not the class itself.
    W2 = tf.get_variable(
        "W2", initializer=tf.random_normal_initializer(), shape=[2, 2])
    x = tf.placeholder(tf.float32, shape=[2, 1], name="x")
    b = tf.constant([8, 6], dtype=tf.float32, shape=[2, 1], name="b")
    z = tf.matmul(
        tf.matmul(W, W2, name="matmul_inside"), x, name="matmul_outside") + b
    y = tf.nn.relu(z, name="relu")
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run([y], feed_dict={x: [[1.0], [1.0]]})
    # Still inside the graph context, so this is the graph built above.
    graph = tf.get_default_graph()
    print(graph.get_operations())
    # "add" is the auto-generated name of the op created by the `+ b` above.
    op1 = graph.get_operation_by_name("add")
    print(op1)
'''
结果:
name: "add"
op: "Add"
input: "matmul_outside"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
这个'add'op是计算 matmul_outside + b (即 (W·W2)·x + b) 时产生的
'''
- tf.Graph.get_tensor_by_name()
'''
作用: 通过名称获取tensor
参数:
name: Tensor的名称, 格式为<op_name>:<output_index>, Tensor是由各种Op产生的
输出: Tensor, 指定获取的Tensor
'''
# Build a small graph computing relu((W @ W2) @ x + b), then fetch a tensor by
# its "<op_name>:<output_index>" name and evaluate it.
with tf.Graph().as_default():
    W = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32, name="W")
    # Pass an initializer instance; the documented contract of `initializer`
    # is an initializer object or a Tensor, not the class itself.
    W2 = tf.get_variable(
        "W2", initializer=tf.random_normal_initializer(), shape=[2, 2])
    x = tf.placeholder(tf.float32, shape=[2, 1], name="x")
    b = tf.constant([8, 6], dtype=tf.float32, shape=[2, 1], name="b")
    z = tf.matmul(
        tf.matmul(W, W2, name="matmul_inside"), x, name="matmul_outside") + b
    y = tf.nn.relu(z, name="relu")
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run([y], feed_dict={x: [[1.0], [1.0]]})
        graph = tf.get_default_graph()
        # "b:0" is output 0 of the op named "b" (the constant defined above).
        tensor = graph.get_tensor_by_name("b:0")
        print(tensor)
        # `eval()` needs a default session, so it must stay inside this block.
        print(tensor.eval())
'''
结果:
Tensor("b:0", shape=(2, 1), dtype=float32)
[[ 8.]
[ 6.]]
'''
- tf.Graph.is_feedable()
- tf.Graph.is_fetchable()
'''
作用:
is_feedable: 判断Tensor能否被feed;
is_fetchable: 判断Tensor或Op能否被fetch;
参数:
is_feedable:
tensor: Tensor对象 (该方法只检查对象是否被标记为不可feed, 不会解析名称; 传入字符串名称时总是返回True, 并不代表真实结果)
is_fetchable
tensor_or_op: Tensor对象或Op对象 (同样不会解析字符串名称)
输出: bool
'''
# Build a small graph computing relu((W @ W2) @ x + b), then query whether
# specific tensors / ops can be fed or fetched.
with tf.Graph().as_default():
    W = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32, name="W")
    # Pass an initializer instance; the documented contract of `initializer`
    # is an initializer object or a Tensor, not the class itself.
    W2 = tf.get_variable(
        "W2", initializer=tf.random_normal_initializer(), shape=[2, 2])
    x = tf.placeholder(tf.float32, shape=[2, 1], name="x")
    b = tf.constant([8, 6], dtype=tf.float32, shape=[2, 1], name="b")
    z = tf.matmul(
        tf.matmul(W, W2, name="matmul_inside"), x, name="matmul_outside") + b
    y = tf.nn.relu(z, name="relu")
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run([y], feed_dict={x: [[1.0], [1.0]]})
    # Still inside the graph context, so this is the graph built above.
    graph = tf.get_default_graph()
    tensor1 = graph.get_tensor_by_name("W2:0")
    tensor2 = graph.get_tensor_by_name("x:0")
    # Pass Tensor / Operation objects rather than name strings: these methods
    # do not resolve names, so a string argument would trivially yield True.
    print(tensor1, graph.is_feedable(tensor1), graph.is_fetchable(tensor1))
    print(tensor2, graph.is_feedable(tensor2), graph.is_fetchable(tensor2))
    op1 = graph.get_operation_by_name("relu")
    op2 = graph.get_operation_by_name("W")
    # Feedability is a property of tensors, so query each op's first output.
    print("op1", graph.is_feedable(op1.outputs[0]), graph.is_fetchable(op1))
    print("op2", graph.is_feedable(op2.outputs[0]), graph.is_fetchable(op2))
'''
结果:
Tensor("W2:0", shape=(2, 2), dtype=float32_ref) True True
Tensor("x:0", shape=(2, 1), dtype=float32) True True
op1 True True
op2 True True
'''
- tf.Graph.device()
'''
作用: 在with上下文管理器中使用, 在with下创建的Op使用graph.device()指定的设备
参数:
device_name_or_function: 设备的名称, 形如"/gpu:0"或"/cpu:0", 不能使用list传递多个
'''
# Pin ops to a device through the graph's own `device()` context manager.
with tf.Graph().as_default():
    graph = tf.get_default_graph()
    with graph.device("/cpu:0"):
        # Ops created inside this block are assigned to the requested device.
        a = tf.constant(1.0)
        print(a.device)
'''
结果:
/device:CPU:0
'''
# In general, tf.device() is used directly to choose the device an op runs on.
with tf.device("/cpu:0"):
    a = tf.constant(1.0)
    print(a.device)
'''
结果:
/device:CPU:0
'''
- tf.Graph.name_scope()
'''
作用: 返回一个上下文管理器, 为图中的Op创建层级的名称;
参数:
name: 该层scope的名称
注意:
1. name_scope只给Op(操作)加前缀(tf.Variable创建的变量同样会被加前缀), 不会给通过tf.get_variable()获取的Variable加前缀: tf.get_variable()创建的变量名称由variable_scope管理, 会忽略name_scope
2. tf.Graph中没有variable_scope()方法
'''
# Show which op names receive a prefix from nested `Graph.name_scope()` blocks.
with tf.Graph().as_default() as g:
    a = tf.constant(1.0, name="a")  # created outside any name scope
    with g.name_scope("level1_a"):
        b = tf.constant([1.1, 1.2], name="b")
        c = tf.Variable(0, name="c")
        with g.name_scope("level2_a"):
            # Created via tf.get_variable(); compare its printed name below
            # with the names of `e` and `f`, created in the same scope.
            d = tf.get_variable(shape=[3, 3, 3], name="d")
            e = tf.nn.relu(b, name="e")
            f = tf.zeros(shape=[3, 3], dtype=tf.float32, name="f")
    print(a.op.name)
    print(b.op.name)
    print(c.op.name)
    print(d.op.name)
    print(e.op.name)
    print(f.op.name)
'''
结果:
a
level1_a/b
level1_a/c
d
level1_a/level2_a/e
level1_a/level2_a/f
可以看出, name_scope()方法只给Op加前缀, 而由get_variable()方法创建的d, 就没有被加前缀
'''
- tf.Graph.finalize()
'''
作用: 结束图, 使得图只读(read-only). 调用 g.finalize()之后,新的操作就不能够添加到g里面去了
'''
# Demonstrate that a finalized graph rejects new ops.
with tf.Graph().as_default() as g:
    a = tf.constant(1.0, name="a")
    g.finalize()
    # Intentionally fails: adding an op after finalize() raises RuntimeError.
    b = tf.constant(1.0, name="b")
'''
结果: (报错)
RuntimeError: Graph is finalized and cannot be modified.
'''