tf.cond()条件控制语句执行出错

问题现场:

import tensorflow as tf
parsed_example = {
    "label":[[1, 1, 1, 0],
              [0, 1, 1, 0],
             [1, 1, 1, 0]],
    "ot": [[1, 1, 1, -1],
              [0, 1, 1, 0],
             [1, 1, 1, 0]],
}
label = tf.slice(tf.cast(parsed_example["label"], tf.int32), [0, 0], [-1, 1])
sess = tf.Session()
label = sess.run(label)
print label

ot = tf.slice(tf.cast(parsed_example["label"], tf.int32), [0, 0], [-1, 1])
ot = sess.run(ot)
print ot
test = tf.cond(ot == 1, 0, 1)
test = sess.run(test)
print test

解决:

typyerror表示:true_fnfalse_fn是不可调用的。(说明上面的代码中的第二个参数和第三个参数类型错误,并且不是tensor,返回结果必须是tensor。)

test = tf.cond(ot == ot1, lambda: tf.add(ot, ot), lambda : tf.add(ot1, ot1))

将fn改为参数后,报错:

    raise TypeError("pred must not be a Python bool")
TypeError: pred must not be a Python bool

解决:

个类型错误,或者不兼容的问题,Python中的True不是tf.bool类型,所以导致不兼容,只要定义的时候加上type就行了,举个例子:

tf.cast(True, tf.bool)

tf.cond()用法

在TensorFlow中,tf.cond()类似于c语言中的if...else...,用来控制数据流向,但是仅仅类似而已,其中差别还是挺大的。关于tf.cond()函数的具体操作,参考了tf的说明文档。

重点:

  • 输入参数中pred的类型必须是tensorflow的bool类型(tf.bool,可通过tf.cast(True, tf.bool)来转换)
  • fnq'和‘fn2’返回的是非零的且类型相同的输出。得是函数的形式,如果你输出是常量,可以使用匿名函数的形式作为函数

例子:

经过debug,上述例子可以运行的代码:

import tensorflow as tf
parsed_example = {
    "label":[[1, 1, 1, 0],
              [0, 1, 1, 0],
             [1, 1, 1, 0]],
    "ot": [[1, 1, 1, -1],
              [0, 1, 1, 0],
             [1, 1, 1, 0]],
}
label = tf.slice(tf.cast(parsed_example["label"], tf.int32), [0, 0], [-1, 1])
sess = tf.Session()
label = sess.run(label)
# print label

ot = tf.slice(tf.cast(parsed_example["label"], tf.int32), [0, 0], [-1, 1])
ot1 = tf.slice(tf.cast(parsed_example["label"], tf.int32), [0, 1], [-1, 1])
# ot = sess.run(ot)
# print ot
a = tf.add(ot, ot)
b = tf.add(ot1, ot1)
print sess.run(a)
print sess.run(b)
print ot == ot1
print ot == 1

test = tf.cond(tf.cast(ot == ot1, tf.bool), lambda: tf.add(ot, ot), lambda : tf.add(ot1, ot1))
test = sess.run(test)
print test

倒数第三句代码中,对输出常量,可以改写为匿名函数的形式:

test = tf.cond(tf.cast(ot == ot1, tf.bool), lambda: 1, lambda : 2)

参考:

1.https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/cond

2.帮助理解:https://blog.csdn.net/m0_37041325/article/details/76908660

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值