feed_dict参数的作用是替换图中的某个tensor的值。例如:
a = tf.add(2, 5)
b = tf.multiply(a, 3)
with tf.Session() as sess:
sess.run(b)
21
replace_dict = {a: 15}
sess.run(b, feed_dict = replace_dict)
45
这样做的好处是在某些情况下可以避免一些不必要的计算。
除此之外,feed_dict还可以用来设置graph的输入值,这就引入了
x = tf.placeholder(tf.float32, shape=(1, 2))
w1 = tf.Variable(tf.random_normal([2, 3],stddev=1,seed=1))
w2 = tf.Variable(tf.random_normal([3, 1],stddev=1,seed=1))
a = tf.matmul(x,w1)
y = tf.matmul(a,w2)
with tf.Session() as sess:
# 变量运行前必须做初始化操作
init_op = tf.global_variables_initializer()
sess.run(init_op)
print(sess.run(y, feed_dict={x:[[0.7, 0.5]]}))
[[3.0904665]]
或者 多输入
x = tf.placeholder(tf.float32, shape=(None, 2))
w1 = tf.Variable(tf.random_normal([2,3],stddev=1,seed=1))
w2 = tf.Variable(tf.random_normal([3,1],stddev=1,seed=1))
a = tf.matmul(x,w1)
y = tf.matmul(a,w2)
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
print(sess.run(y, feed_dict={x:[[0.7,0.5],[0.2,0.3],[0.3,0.4],[0.4,0.5]]}))
print(sess.run(w1))
print(sess.run(w2))
[[3.0904665]
[1.2236414]
[1.7270732]
[2.2305048]]
[[-0.8113182 1.4845988 0.06532937]
[-2.4427042 0.0992484 0.5912243 ]]
[[-0.8113182 ]
[ 1.4845988 ]
[ 0.06532937]]
注意:此时的a不是一个tensor,而是一个placeholder。我们定义了它的type和shape,但是并没有具体的值。在后面定义graph的代码中,placeholder看上去和普通的tensor对象一样。在运行程序的时候我们用feed_dict的方式把具体的值提供给placeholder,达到了给graph提供input的目的。
placeholder有点像在定义函数的时候用到的参数。我们在写函数内部代码的时候,虽然用到了参数,但并不知道参数所代表的值。只有在调用函数的时候,我们才把具体的值传递给参数。