tf.where()函数的语法格式如下:
import tensorflow as tf
tf.where(
condition,
x=None,
y=None,
name=None
)
作用:该函数的作用是根据condition,返回相对应的x或y,返回值是一个tf.bool类型的Tensor。
例1:
import tensorflow as tf
sess=tf.Session()
A =tf.where(False,123,321)
>>> print(A)
Tensor("Select:0", shape=(), dtype=int32)
>>> print(sess.run(A))
321
>>> B=tf.where(True,123,321)
>>> print(sess.run(B))
123
sess.close()
例2:
sess=tf.Session()
>>> X = [["China","Henan","Changsha"],
... ["You","Love","China"]]
>>> Y = [["America","Shanxi","lvliang"],
... ["I","Like","Country"]]
print(sess.run(tf.where(condition_1,X,Y)))
[[b'China' b'Shanxi' b'lvliang']
[b'I' b'Love' b'China']]
sess.close()
由以上两个例子我们可以清楚地看到,tf.where()的作用就是根据condition返回相对应的X 或 Y值。若condition=True,则返回对应X的值,False则返回对应的Y值。
以上内容,如有错误,敬请批评指正!谢谢!