Tensorflow小技巧整理:tf.cond()的小应用
tf.cond() 方法
tf.cond() 的作用类似于平常所使用的 if…else… 语句,但是在 tensorflow 中,所有节点是用图来保存的。而在图中传输的这些 tensor 数据流在我们使用 Session().run() 之前又是无法确定其数值的,所以这个时候传统的条件判断语句就无法使用。
比如我们想判断 a 和 b 是否相等:
a = tf.constant(3)
b = tf.constant(3)
# a 和 b 直接打印都是 <tf.Tensor '...' shape=() dtype=int32>
如果我们直接使用 ’ a == b ’ 判断,得到的是 ‘False’,而如果使用 ‘tf.equal(a,b)’ 来判断,返回的是一个新的张量:
is_equal = tf.equal(a,b)
# is_equal 是 <tf.Tensor '...' shape=() dtype=bool>
# 如果这个时候我们使用if语句
if is_equal:
print('a = b')
则会报错,错误信息为:
TypeError: Using a
tf.Tensor
as a Pythonbool
is not allowed.
并会提示你使用 tf.cond 来处理条件判断。对于tf.cond() 方法,官方给出的是:
tf.cond(
pred,
true_fn=None,
false_fn=None,
strict=False,
name=None,
fn1=None,
fn2=None
)
而实际上主要使用的只有前三个参数,所以如果我们简化为 tf.cond(pred, fn1, fn2) 的形式来看的话,这种形式像极了 java 中的 “?:” 的三元运算符。在经过 pred 条件进行筛选后,选择返回函数 fn1 还是 fn2 的值作为输出。
一个小例子
举个小例子,在做生成模型的时候,生成长语句是会出现重复连续输出相同词汇的问题,机器输出的语句就像结巴了一样,效果很差。传统方法会使用 Beam Search 算法,选择多个可能性的词句并进行存储,来增大输出的可能性。利用这个思想我们可以稍微简化一下,如果遇到重复我们直接依据概率选择第二可能的词汇输出,如果和上一个词没有重复则直接输出最大可能性的词汇。代码大致如下:
# phrase 是一个数组,存储每一次decoder输出的词汇
# this_word_output 是decoder过完softmax后的输出,通过这个输出选择词汇
# this_word_id 是经过 decoder 输出过 softmax 层后选择的最优词汇的 index
# next_word_vec 是生成词汇后,将生成词汇进行 embedding 转为词向量后作为下一步的输入
def f1():
id_, next_word_vec = select_second_possible_word(this_word_output, this_word_id)
return id_, next_word_vec
def f2():
return this_word_id, next_word_vec
this_word_id, next_word_vec = tf.cond(tf.equal(this_word_id,phrase[-1]),f1,f2)
因为只是代码一部分,所以有些参数有些突兀。大体意思就是,phrase是一个数组来储存生成的所有词汇,最终连成一句完整的语句。 decoder输出一个向量并经过softmax选择最优词汇的 index(代码中用id表示)后,与phrase的最后一个词作比较,如果连续输出相同的词汇,则依据 select_second_possible_word() 函数来选择第二可能的词汇进行输出,返回的是对应的 index 和新的词向量。 这里判断部分的语句就是使用 tf.cond()方法,利用 tf.equal()来判别当前输出的 this_word_id 和 句子数组中最后一个单词是否重复,以此来决定当前输出的词汇是 f1 函数选择的次优词汇,还是 f2 函数直接输出的最优词汇。
这个方法比较简单粗暴,效果不如 Beam Search 算法,但是实现起来比较简单,在一定程度上改善了输出重复的问题,经过测试相较于每次直接选择最优输出,还是有所改进的。