Tensorflow小技巧整理:tf.cond()的小应用

Tensorflow小技巧整理:tf.cond()的小应用

tf.cond() 方法

tf.cond() 的作用类似于平常所使用的 if…else… 语句,但是在 tensorflow 中,所有节点是用图来保存的。而在图中传输的这些 tensor 数据流在我们使用 Session().run() 之前又是无法确定其数值的,所以这个时候传统的条件判断语句就无法使用。

比如我们想判断 a 和 b 是否相等:

a = tf.constant(3)
b = tf.constant(3)
# a 和 b 直接打印都是 <tf.Tensor '...' shape=() dtype=int32>

如果我们直接使用 ’ a == b ’ 判断,得到的是 ‘False’,而如果使用 ‘tf.equal(a,b)’ 来判断,返回的是一个新的张量:

is_equal = tf.equal(a,b)
# is_equal 是 <tf.Tensor  '...'  shape=()  dtype=bool>
# 如果这个时候我们使用if语句
if is_equal:
    print('a = b')

则会报错,错误信息为:

TypeError: Using a tf.Tensor as a Python bool is not allowed.

并会提示你使用 tf.cond 来处理条件判断。对于tf.cond() 方法,官方给出的是:

tf.cond(
    pred,
    true_fn=None,
    false_fn=None,
    strict=False,
    name=None,
    fn1=None,
    fn2=None
)

而实际上主要使用的只有前三个参数,所以如果我们简化为 tf.cond(pred, fn1, fn2) 的形式来看的话,这种形式像极了 java 中的 “?:” 的三元运算符。在经过 pred 条件进行筛选后,选择返回函数 fn1 还是 fn2 的值作为输出。

一个小例子

举个小例子,在做生成模型的时候,生成长语句是会出现重复连续输出相同词汇的问题,机器输出的语句就像结巴了一样,效果很差。传统方法会使用 Beam Search 算法,选择多个可能性的词句并进行存储,来增大输出的可能性。利用这个思想我们可以稍微简化一下,如果遇到重复我们直接依据概率选择第二可能的词汇输出,如果和上一个词没有重复则直接输出最大可能性的词汇。代码大致如下:

# phrase 是一个数组,存储每一次decoder输出的词汇
# this_word_output 是decoder过完softmax后的输出,通过这个输出选择词汇
# this_word_id 是经过 decoder 输出过 softmax 层后选择的最优词汇的 index
# next_word_vec 是生成词汇后,将生成词汇进行 embedding 转为词向量后作为下一步的输入
def f1():
    id_, next_word_vec = select_second_possible_word(this_word_output, this_word_id)
    return id_, next_word_vec
def f2():
    return this_word_id, next_word_vec
this_word_id, next_word_vec = tf.cond(tf.equal(this_word_id,phrase[-1]),f1,f2)

因为只是代码一部分,所以有些参数有些突兀。大体意思就是,phrase是一个数组来储存生成的所有词汇,最终连成一句完整的语句。 decoder输出一个向量并经过softmax选择最优词汇的 index(代码中用id表示)后,与phrase的最后一个词作比较,如果连续输出相同的词汇,则依据 select_second_possible_word() 函数来选择第二可能的词汇进行输出,返回的是对应的 index 和新的词向量。 这里判断部分的语句就是使用 tf.cond()方法,利用 tf.equal()来判别当前输出的 this_word_id 和 句子数组中最后一个单词是否重复,以此来决定当前输出的词汇是 f1 函数选择的次优词汇,还是 f2 函数直接输出的最优词汇。

这个方法比较简单粗暴,效果不如 Beam Search 算法,但是实现起来比较简单,在一定程度上改善了输出重复的问题,经过测试相较于每次直接选择最优输出,还是有所改进的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值