Tensorflow深度学习之二十八:tf.cond

一、简介

def cond(pred, # 谓词,可以理解为判断条件
         true_fn=None, # 当谓词为真(True)时返回的函数
         false_fn=None, # 当谓词为假(False)时返回的函数
         strict=False, #
         name=None,
         fn1=None,
         fn2=None):

API注释:
Return true_fn() if the predicate pred is true else false_fn().

true_fn and false_fn both return lists of output tensors. true_fn and false_fn must have the same non-zero number and type of outputs.

Note that the conditional execution applies only to the operations defined in true_fn and false_fn. Consider the following simple program:

python
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

If x < y, the tf.add operation will be executed and tf.square operation will not be executed. Since z is needed for at least one branch of the cond, the tf.multiply operation is always executed, unconditionally.
Although this behavior is consistent with the dataflow model of TensorFlow, it has occasionally surprised some users who expected a lazier semantics.

Note that cond calls true_fn and false_fn exactly once (inside the call to cond, and not at all during Session.run()). cond stitches together the graph fragments created during the true_fn and false_fn calls with some additional graph nodes to ensure that the right branch gets executed depending on the value of pred.

tf.cond supports nested structures as implemented in tensorflow.python.util.nest. Both true_fn and false_fn must return the same (possibly nested) value structure of lists, tuples, and/or named tuples.
Singleton lists and tuples form the only exceptions to this: when returned by true_fn and/or false_fn, they are implicitly unpacked to single values. This behavior is disabled by passing strict=True.

Google翻译:

如果谓词pred为真,则返回true_fn(),否则返回false_fn()

true_fnfalse_fn都返回输出张量列表。 true_fnfalse_fn必须具有相同的非零数字和输出类型。

请注意,条件执行仅适用于true_fnfalse_fn中定义的操作。考虑以下简单程序:

z = tf.multiply(a,b)
result = tf.cond(x <y,lambda:tf.add(x,z),lambda:tf.square(y))

如果x <y,将执行tf.add操作并且不执行tf.square操作。由于cond的至少一个分支需要z,所以总是无条件地执行tf.multiply操作。
虽然这种行为与TensorFlow的数据流模型一致,但它偶尔会让一些期望更加懒惰语义的用户感到惊讶。

注意cond只调用一次true_fnfalse_fncond的调用中,在Session.run()期间不调用)。 cond将在true_fnfalse_fn调用期间创建的图形片段与一些额外的图形节点拼接在一起,以确保根据pred的值执行正确的分支。

tf.cond支持在tensorflow.python.util.nest中实现的嵌套结构。 true_fnfalse_fn都必须返回列表,元组和/或命名元组的相同(可能是嵌套的)值结构。
单例列表和元组构成了对此的唯一例外:当由true_fn和/或false_fn返回时,它们被隐式解压缩为单个值。通过传递strict = True禁用此行为。

总结:该函数类似与if...else... 分支,当谓词判断为真时,调用前面一个函数,谓词判断为假时则调用后面一个函数。这在写程序时很有用,因为在TensorFlow中,我们需要先建立Graph,此时数据是不可知的,常规方法并不能直接判断,这里就提供了一个借口,可以在数据未知时进行判断。pred: A scalar determining whether to return the result of true_fn or
false_fn.
true_fn: The callable to be performed if pred is true.
false_fn: The callable to be performed if pred is false.
strict: A boolean that enables/disables ‘strict’ mode; see above.
name: Optional name prefix for the returned tensors.

二、参数
   在实际的使用过程中,我们一般只需要使用以下参数即可。

参数
predA scalar determining whether to return the result of true_fn or false_fn.一个标量,或者说是一个判断条件,用以判断返回true_fn 或者 false_fn
true_fnThe callable to be performed if pred is true.pred 为真时,返回的函数
false_fnThe callable to be performed if pred is false.pred 为假时,返回的函数
strictA boolean that enables/disables ‘strict’ mode; see above.一个bool值,表示是否使用’strict’模式,详见上
nameOptional name prefix for the returned tensors.名称,可选参数

三、代码

import tensorflow as tf
import numpy as np

x = tf.constant(2)
y = tf.constant(1)


def f1(): return tf.multiply(x, 17)


def f2(): return tf.add(y, 23)


r = tf.cond(tf.less(x, y), f1, f2)

with tf.Session() as sess:
    print(sess.run(r))

运行结果:因为2<1为False,执行f2,得到结果1+23=24

24
import tensorflow as tf
import numpy as np

x = tf.constant(2)
y = tf.constant(5) # 与前面程序的区别仅仅是y取值不同


def f1(): return tf.multiply(x, 17)


def f2(): return tf.add(y, 23)


r = tf.cond(tf.less(x, y), f1, f2)

with tf.Session() as sess:
    print(sess.run(r))

运行结果:因为2<5为True,这里执行f1,返回2*17=34。

34

为了方便,也可以使用lambda来定义函数。

# coding=utf-8
import tensorflow as tf
import numpy as np

a = tf.placeholder(dtype=tf.float32)

# 随便定义一些计算逻辑
b = tf.add(a, 32)

c = tf.add(a, 56)

res = tf.cond(a < 10, lambda: b + 10, lambda: c * 2)

with tf.Session() as sess:
    print(sess.run(res, feed_dict={a: 13}))

计算结果:

138.0
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这是一个线性回归模型的结果表格。该模型的因变量为 count,自变量为 temp。下面是对表格中的各项进行解释: - Dep. Variable: count:指定了模型的因变量为 count。 - R-squared: 0.156:决定系数 R-squared 为 0.156,表示自变量 temp 可解释因变量 count 的 15.6% 的方差。 - Model: OLS:模型采用最小二乘法(OLS)进行估计。 - Adj. R-squared: 0.156:校正后的决定系数为 0.156,表示模型已经考虑了自变量的个数和样本量对 R-squared 的影响。 - Method: Least Squares:估计方法为最小二乘法。 - F-statistic: 2006.00:F 统计量为 2006.00,用于检验自变量 temp 是否显著影响因变量 count,P 值为 0.00,表明自变量显著影响因变量。 - Prob (F-statistic): 0.00:P 值为 0.00,表明自变量显著影响因变量。 - Time: 13:53:24:模型估计的时间为 13:53:24。 - Log-Likelihood: -71125:对数似然值为 -71125,用于衡量模型的拟合程度。 - No. Observations: 10886:样本量为 10886。 - AIC: 1.423e+05:信息准则 AIC 为 1.423e+05,用于比较模型的优劣性。 - Df Residuals: 10884:剩余自由度为 10884。 - BIC: 1.423e+05:贝叶斯信息准则 BIC 为 1.423e+05,用于比较模型的优劣性。 - Df Model: 1:模型自由度为 1。 - Covariance Type: nonrobust:协方差类型为非鲁棒型,即不考虑异常值对模型的影响。 - coef:系数估计值。 - std err:系数估计值的标准误差。 - t:t 统计量的值。 - P>|t|:t 统计量的 P 值,用于检验系数是否显著。 - [0.025 0.975]:系数估计值的置信区间。 模型的 RMSE(均方根误差)为 161.63,表示模型的平均预测误差为 161.63。同时,l 的值也为 161.63,可能是你在表格中标注的一个符号。但是,我并不清楚它的含义和作用。希望我的解释能对你有所帮助。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值