1、tf.where()用法
tf.where(condition, x=None, y=None, name=None)
- condition, x, y 相同维度,condition是bool型值,True/False
- 返回值是对应元素
- condition=True,元素取值为x中的元素
- condition=False,元素取值为y中的元素
代码:
import tensorflow as tf
x = [[1,2,3],[4,5,6]]
y = [[7,8,9],[10,11,12]]
condition_all_x = [[True,True,True], [True,True,True]] #取x中全部元素
condition_all_y = [[False,False,False], [False,False,False]] #取y中全部元素
condition_x_y = [[True,True,True], [False,False,False]] #取x、y中全部元素各一半
condition1 = [[True,False,False], [False,True,True]]
condition2 = [[True,False,False], [True,True,False]]
with tf.Session() as sess:
print("取值全部为x中元素:\n",sess.run(tf.where(condition_all_x,x,y)))
print("取值全部为y中元素:\n",sess.run(tf.where(condition_all_y,x,y)))
print("取值为x/y中元素各一半:\n",sess.run(tf.where(condition_all_y,x,y)))
print(sess.run(tf.where(condition1,x,y)))
print(sess.run(tf.where(condition2,x,y)))
输出: