tf.where
tf.where(condition, x=None, y=None, name=None)
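# condition 是 bool 型张量;若同时给出 x、y,则三者形状需一致
# (TF1 中 condition 也可以是长度等于 x 第一维的一维向量;TF2 的 tf.where 支持广播)
# 若 x、y 均为 None,则返回 condition 中值为 True 的元素的坐标:
# 结果是形状为 [True 的个数, condition 的维数] 的 int64 张量,每一行是一个 True 元素的索引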
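>>> import tensorflow as tf
>>> sess = tf.Session()   # TF1 写法;TF2 中可直接用 tf.where(condition1).numpy()
>>> condition1 = [[True, False, False],
...               [False, True, True]]
>>> print(sess.run(tf.where(condition1)))
[[0 0]
 [1 1]
 [1 2]]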
# 如果同时给出 x 和 y:condition 为 True 的位置取 x 对应位置的元素,为 False 的位置取 y 对应位置的元素
# 下例:
# Example: tf.where(condition, x, y) takes the element of x where condition is
# True and the element of y where it is False.
import tensorflow as tf

x = [[1, 2, 3], [4, 5, 6]]
y = [[7, 8, 9], [10, 11, 12]]
condition3 = [[True, False, False],
              [False, True, True]]
condition4 = [[True, False, False],
              [True, True, False]]

# Build both ops up front so the same list can be evaluated in either mode below.
results = [tf.where(condition3, x, y), tf.where(condition4, x, y)]

if hasattr(tf, "Session"):
    # TF 1.x: ops only build a graph; their values come from Session.run.
    with tf.Session() as sess:
        values = sess.run(results)
else:
    # TF 2.x removed tf.Session (now tf.compat.v1.Session) and runs ops eagerly
    # by default, so the values can be read directly from the tensors.
    values = [tensor.numpy() for tensor in results]

for value in values:
    print(value)
# 输出(第一个数组对应 condition3,第二个对应 condition4):
[[ 1  8  9]
 [10  5  6]]
[[ 1  8  9]
 [ 4  5 12]]
tf.gather 和 tf.gather_nd
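这两个函数都是根据索引从张量中取出元素或切片的方法:
tf.gather(params, indices, axis=0)
# 沿 params 的第 axis 维,按 indices 中的值依次取出切片
tf.gather_nd(params, indices)
# indices 的最后一维表示 params 中的(部分)坐标:坐标长度等于 params 的维数时取出单个元素,小于时取出对应的切片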
示例:
# Example: tf.gather along axis 0 and axis 1, and tf.gather_nd with
# (row, column) coordinates, all on the same 4x9 matrix.
import numpy as np
import tensorflow as tf

probs = np.array([
    [0, 11, 21, 31, 41, 51, 61, 71, 81],
    [0, 12, 22, 32, 42, 52, 62, 72, 82],
    [0, 13, 23, 33, 43, 53, 63, 73, 83],
    [0, 14, 24, 34, 44, 54, 64, 74, 84],
])
# Each row of indices_nd is one (row, column) coordinate into probs.
indices_nd = np.array([
    [0, 7],
    [1, 6],
    [2, 6],
    [3, 1],
])
indices_0 = np.array([1, 3])  # row numbers to pick (axis=0)
indices_1 = np.array([7, 3])  # column numbers to pick (axis=1)

labels = ["tf.gather axis=0 \n", "tf.gather axis=1 \n", "tf.gather_nd"]
results = [
    tf.gather(probs, indices_0, axis=0),
    tf.gather(probs, indices_1, axis=1),
    tf.gather_nd(probs, indices_nd),
]

if hasattr(tf, "Session"):
    # TF 1.x: ops only build a graph; their values come from Session.run.
    with tf.Session() as sess:
        values = sess.run(results)
else:
    # TF 2.x removed tf.Session (now tf.compat.v1.Session) and runs ops eagerly
    # by default, so the values can be read directly from the tensors.
    values = [tensor.numpy() for tensor in results]

for label, value in zip(labels, values):
    print(label, value)
输出:
tf.gather axis=0
[[ 0 12 22 32 42 52 62 72 82]
[ 0 14 24 34 44 54 64 74 84]]
tf.gather axis=1
[[71 31]
[72 32]
[73 33]
[74 34]]
tf.gather_nd [71 62 63 14]