众所周知,Tensorflow中没有分支结构。实际搭建模型的过程中,可能会遇到需要根据张量的状态决定执行哪部分代码的情况,那我们应该如何应对这种情况呢?今天,我这个杂牌川军,就和大家分享一个利用Tensorflow API实现的分支结构,好不好用全看个人悟性(真不要脸,说得自己好像大神。)。Ladies and Gentmen, Let's begin!
首先,介绍我们的猪脚:tf.cond()。
我说它能实现分支,可能大家不信。毕竟人微言轻(屁都不算),所以祭出大杀器——Tensorflow官方描述(嘻嘻嘻)。 tf.cond(pred, true_fn=None, false_fn=None, strict=False, name=None, fn1=None, fn2=None)
如果pred为真,返回true_fn(),否则返回false_fn()。true_fn和false_fn都返回输出张量列表,同时true_fn和false_fn具有相同的非零数字和输出类型。
下面,我们举个栗子(公司无法上传图片,所以没有运行截图):
import tensorflow as tf
a = tf.constant(2)
b = tf.constant(3)
def fun_1():
return a + b
def fun_2():
return a * b
c = tf.cond(tf.less(a, b), fun_1, fun_2)
with tf.Session() as sess:
print(sess.run(c))
最终结果为5。
在这里需要注意的地方有:
1.true_fn、false_fn不能带参数;
2.true_fn、false_fn必须有返回值;
3.如果true_fn、false_fn需要参数,可以将参数放在true_fn、false_fn函数体的前面(正如蓝色加粗部分所示)。