关于如何获取某一个tensor,大家可能很快就能想到使用tf.get_default_graph().get_tensor_by_name()
这个函数。对于普通的tensor或者op其实到这就完成了。但由于layer的一些特殊性,我们需要进一步的了解。那就从最直觉的做法,看看会发生什么。
一个BUG
import tensorflow as tf
import numpy as np
inputs = tf.placeholder(tf.float32, [None, 3])
layer_out = tf.layers.Dense(5, name="outputs_1")(inputs)
output = tf.layers.Dense(1, name="outputs_2")(layer_out)
with tf.Session() as sess:
sess.run(init)
mat = np.ones([2,3])
fout = sess.run(tf.get_default_graph().get_tensor_by_name("outputs_1:0"), feed_dict={inputs:mat})
print(fout)
正确运行后,你就收到如下的报错提示
KeyError: " The name ‘outputs_1:0’ refers to a Tensor which dose not exist. The operation, ‘outputs_1’ does not exist in the graph."
此时,你肯定地铁老人看手机,满脸不理解。怎么可能,我明明指定了name='outputs_1’啊。
我们不妨冷静地看看报错信息,这里面提到outputs_1不作为一个op存在。然后你可以试试将 get_tensor_by_name括号中换成outputs_1/kerner:0。诶,这会正常输出了,这至少说明了name那边指定是正常的。
那顺着报错信息思考,虽然指定了name,但会不会这个name确实不代表op,而是代表其他。
layers.dense中name指代什么
那么我们就来看layers.dense这里的name指代的到底是啥。
dense层具有偏差bias和权重weight参数,因此在命名时,是将命名整个layer为“outputs_1”,而并不是对这个layer的输出张量tensor进行命名。
将名为"outpus_1"的结点展开后为:
所以,到此我们就能知道报错的原因了。但是如何实现我们的需求呢?
获取中间层输出
在这里有两种办法,一种是在网络层定义的过程进行修改,一种是利用默认命名规则实现。
方法一
使用tf.identity()函数来定义一个网络层输出的替代节点。具体如下:
import tensorflow as tf
import numpy as np
inputs = tf.placeholder(tf.float32, [None, 3])
layer_out = tf.layers.Dense(5, name="outputs_1")(inputs)
fl_out = tf.identity(layer_out, name="fout")
output = tf.layers.Dense(1, name="outputs_2")(layer_out)
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
mat = np.ones([2,3])
fout = sess.run(tf.get_default_graph().get_tensor_by_name("fout:0"), feed_dict={inputs:mat})
print(fout)
方法二
我们观察上面的网络节点图,我们可以发现BiasAdd
这个默认命名的结点其实就是该层最终的输出结果节点,所以我们也可以直接根据该名字进行获取。
下面的代码显示了如何获取该结点并基于该结点进行额外的操作。
import tensorflow as tf
import numpy as np
inputs = tf.placeholder(tf.float32, [None, 3])
layer_out = tf.layers.Dense(5, name="outputs_1")(inputs)
output = tf.layers.Dense(1, name="outputs_2")(layer_out)
graph = tf.get_default_graph()
tmp = graph.get_tensor_by_name("outputs_1/BiasAdd:0")
ans = tf.add(tmp, tf.ones([2,5]))
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
mat = np.ones([2,3])
[ans, fout] = sess.run([ans, tf.get_default_graph().get_tensor_by_name("outputs_1/BiasAd:0")], feed_dict={inputs:mat})
print(fout)
print(ans)