错误描述:
在初学tensorflow的过程中,有时执行session.run(tensor_xxx)时,会报出"Tensor XXX is not an element of this graph"这个错误。
原因分析:
从错误的描述可以看出,错误原因是tensor_xxx不在"this graph"里。但是,读到这,一个很自然的问题是"this graph"指的是哪个graph呢?
"this graph"其实指的是在创建session时传入该session里的graph。由session的init函数__init__(target='',graph=None, config=None)可以看出,创建session时,我们需要传入一个图graph,这个图限定了的该session的处理范围——只能处理这个图里的tensor。当graph=None时,传给session的图是tf.get_default_graph()。
解决方案:
有两种方法可以解决这个问题。
方法1:
新建一个session,把tensor_xxx所在的graph传给该session,然后执行session.run(tensor_xxx),示例代码如下:
import tensorflow as tf
graph_tensor = tf.Graph()
with graph_tensor.as_default():
A = tf.constant(1)
sess = tf.Session(graph=graph_tensor)
sess.run(A)
方法2:
如果我们想在一个session中执行不同图里的tensor(例如将两个计算图的计算结果求和),上述方法就不好使了。此时,我们需要把两个图“合并”成一张图,然后传给一个session,示例代码如下:
import tensorflow as tf
with tf.Graph().as_default():
xxx_tensor = tf.constant([1, 2, 3])
ops = {"xxx_tensor": xxx_tensor}
for name, op in ops.items():
tf.add_to_collection(name, op)
metagraph = tf.train.export_meta_graph() # 导出当前图
with tf.Graph().as_default():
tf.train.import_meta_graph(metagraph) # 把一个图导入到当前图
xxx_tensor = tf.get_collection_ref("xxx_tensor")
sv = tf.train.Supervisor()
with sv.managed_session() as session:
fetches = {"xxx_tensor": xxx_tensor}
feed_dict = {}
#session.run(init)
vals = session.run(fetches, feed_dict)
rs = vals["xxx_tensor"]