import tensorflow as tf
A = [[1, 2, 3],
[4, 5, 6]]
B = [[1, 1],
[1, 1],
[1, 1]]
# 自定义图 my_graph
my_graph = tf.Graph()
# 在自定义图中添加 数据节点 和 计算节点
with my_graph.as_default():
A_tf = tf.constant(A, dtype=tf.int32)
B_tf = tf.constant(B, dtype=tf.int32)
C_tf = tf.matmul(A_tf, B_tf)
# 指定在自定义图中会话
with tf.Session(graph=my_graph) as sess:
C = sess.run(C_tf)
print(C)
# 输出为:
# [[ 6 6]
# [15 15]]
计算节点的输出数据对象会被放置到输入数据对象所在的图中
当然, 如下,在不同图中交叉引用数据节点 和 计算节点是会报错的。在构建图时,各个数据对象和计算节点对象必须在当前图中,不同图之间的资源不能交叉引用。
import tensorflow as tf
A = [[1, 2, 3],
[4, 5, 6]]
B = [[1, 1],
[1, 1],
[1, 1]]
# 自定义两个图 my_graph1,my_graph2
my_graph1 = tf.Graph()
my_graph2 = tf.Graph()
# 在自定义图 my_graph1 中添加 数据节点
with my_graph1.as_default():
A_tf = tf.constant(A, dtype=tf.int32)
B_tf = tf.constant(B, dtype=tf.int32)
# 在自定义图 my_graph2 中添加 计算节点
with my_graph2.as_default():
C_tf = tf.matmul(A_tf, B_tf)
print('C_tf.graph is my_graph1:', C_tf.graph is my_graph1)
print('C_tf.graph is my_graph2:', C_tf.graph is my_graph2)
# 输出为:
# C_tf.graph is my_graph1: True
# C_tf.graph is my_graph2: False