Tensorflow小技巧整理:tf.cond()的小应用

Tensorflow小技巧整理:tf.cond()的小应用

tf.cond() 方法

tf.cond() 的作用类似于平常所使用的 if…else… 语句,但是在 tensorflow 中,所有节点是用图来保存的。而在图中传输的这些 tensor 数据流在我们使用 Session().run() 之前又是无法确定其数值的,所以这个时候传统的条件判断语句就无法使用。

比如我们想判断 a 和 b 是否相等:

a = tf.constant(3)
b = tf.constant(3)
# a 和 b 直接打印都是 <tf.Tensor '...' shape=() dtype=int32>

如果我们直接使用 ’ a == b ’ 判断,得到的是 ‘False’,而如果使用 ‘tf.equal(a,b)’ 来判断,返回的是一个新的张量:

is_equal = tf.equal(a,b)
# is_equal 是 <tf.Tensor  '...'  shape=()  dtype=bool>
# 如果这个时候我们使用if语句
if is_equal:
    print('a = b')

则会报错,错误信息为:

TypeError: Using a tf.Tensor as a Python bool is not allowed.

并会提示你使用 tf.cond 来处理条件判断。对于tf.cond() 方法,官方给出的是:

tf.cond(
    pred,
    true_fn=None,
    false_fn=None,
    strict=False,
    name=None,
    fn1=None,
    fn2=None
)

而实际上主要使用的只有前三个参数,所以如果我们简化为 tf.cond(pred, fn1, fn2) 的形式来看的话,这种形式像极了 java 中的 “?:” 的三元运算符。在经过 pred 条件进行筛选后,选择返回函数 fn1 还是 fn2 的值作为输出。

一个小例子

举个小例子,在做生成模型的时候,生成长语句是会出现重复连续输出相同词汇的问题,机器输出的语句就像结巴了一样,效果很差。传统方法会使用 Beam Search 算法,选择多个可能性的词句并进行存储,来增大输出的可能性。利用这个思想我们可以稍微简化一下,如果遇到重复我们直接依据概率选择第二可能的词汇输出,如果和上一个词没有重复则直接输出最大可能性的词汇。代码大致如下:

# phrase 是一个数组,存储每一次decoder输出的词汇
# this_word_output 是decoder过完softmax后的输出,通过这个输出选择词汇
# this_word_id 是经过 decoder 输出过 softmax 层后选择的最优词汇的 index
# next_word_vec 是生成词汇后,将生成词汇进行 embedding 转为词向量后作为下一步的输入
def f1():
    id_, next_word_vec = select_second_possible_word(this_word_output, this_word_id)
    return id_, next_word_vec
def f2():
    return this_word_id, next_word_vec
this_word_id, next_word_vec = tf.cond(tf.equal(this_word_id,phrase[-1]),f1,f2)

因为只是代码一部分,所以有些参数有些突兀。大体意思就是,phrase是一个数组来储存生成的所有词汇,最终连成一句完整的语句。 decoder输出一个向量并经过softmax选择最优词汇的 index(代码中用id表示)后,与phrase的最后一个词作比较,如果连续输出相同的词汇,则依据 select_second_possible_word() 函数来选择第二可能的词汇进行输出,返回的是对应的 index 和新的词向量。 这里判断部分的语句就是使用 tf.cond()方法,利用 tf.equal()来判别当前输出的 this_word_id 和 句子数组中最后一个单词是否重复,以此来决定当前输出的词汇是 f1 函数选择的次优词汇,还是 f2 函数直接输出的最优词汇。

这个方法比较简单粗暴,效果不如 Beam Search 算法,但是实现起来比较简单,在一定程度上改善了输出重复的问题,经过测试相较于每次直接选择最优输出,还是有所改进的。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
这是一个线性回归模型的结果表格。该模型的因变量为 count,自变量为 temp。下面是对表格中的各项进行解释: - Dep. Variable: count:指定了模型的因变量为 count。 - R-squared: 0.156:决定系数 R-squared 为 0.156,表示自变量 temp 可解释因变量 count 的 15.6% 的方差。 - Model: OLS:模型采用最小二乘法(OLS)进行估计。 - Adj. R-squared: 0.156:校正后的决定系数为 0.156,表示模型已经考虑了自变量的个数和样本量对 R-squared 的影响。 - Method: Least Squares:估计方法为最小二乘法。 - F-statistic: 2006.00:F 统计量为 2006.00,用于检验自变量 temp 是否显著影响因变量 count,P 值为 0.00,表明自变量显著影响因变量。 - Prob (F-statistic): 0.00:P 值为 0.00,表明自变量显著影响因变量。 - Time: 13:53:24:模型估计的时间为 13:53:24。 - Log-Likelihood: -71125:对数似然值为 -71125,用于衡量模型的拟合程度。 - No. Observations: 10886:样本量为 10886。 - AIC: 1.423e+05:信息准则 AIC 为 1.423e+05,用于比较模型的优劣性。 - Df Residuals: 10884:剩余自由度为 10884。 - BIC: 1.423e+05:贝叶斯信息准则 BIC 为 1.423e+05,用于比较模型的优劣性。 - Df Model: 1:模型自由度为 1。 - Covariance Type: nonrobust:协方差类型为非鲁棒型,即不考虑异常值对模型的影响。 - coef:系数估计值。 - std err:系数估计值的标准误差。 - t:t 统计量的值。 - P>|t|:t 统计量的 P 值,用于检验系数是否显著。 - [0.025 0.975]:系数估计值的置信区间。 模型的 RMSE(均方根误差)为 161.63,表示模型的平均预测误差为 161.63。同时,l 的值也为 161.63,可能是你在表格中标注的一个符号。但是,我并不清楚它的含义和作用。希望我的解释能对你有所帮助。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值