Tensorflow 中使用tf.cond来控制数据的流向,类似于C语言中的if…else…
语法
format:tf.cond(pred, fn1, fn2, name=None)
例子:
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
上面例子执行这样的操作,如果x<y则result这个操作是tf.add(x,z),反之则是tf.square(y)
z在cond只要至少被一个分支用到,则z被执行。
在实际使用中,可以结合placeholder和feed进行参数的传入,决定计算图中数据流的走向。
is_trainning = tf.placeholder(tf.int16)
feed = {inputs: batch_train_inputs,
targets: batch_train_targets,
is_trainning : 1}
batch_cost, _, train_accuracy= sess.run([cost, optimizer,accuracy], feed)
logit = tf.cond(is_trainning > 0,
lambda: final_fc,
lambda: tf.nn.softmax(final_fc),
name = 'logits')
注意不能tf.bool作为判断条件,因为传入的feed不能为Python bool(简单的False or True)
最简单的解决方法还是传递一个数值。