tf.gather()
首先来看官方的API说明
tf.gather(
params,
indices,
validate_indices=None,
axis=None,
batch_dims=0,
name=None
)
其中我们比较关心的是前两个参数,params和indices
- params :代表需要切片的张量;
- indices :切片的索引;
这里比较迷惑的是params和indices的维度,下面分情况讨论;
params的维度数不等于indices的维度数
例如:params=[1, 0, 1, 0]
indices=[
[3, 0, 3],
[1, 3, 2]
]
arr1 = np.array([
[[0.1,0.1,0.1,0.5],[0.9,0.2,0.3,0.8],[0.1,0.3,0.2,0.6]],
[[0.1,0.6,0.1,0.5],[0.7,0.2,0.3,0.8],[0.1,0.3,0.9,0.6]]
])
tesnsor1 = tf.Variable(arr1)
arr2 = np.array([[1,0,1,0],[1,0,1,0]])
tesnsor2 = tf.Variable(arr2)
res = tf.argmax(tesnsor1,axis=2) # [[3,0,3],[1,3,2]]
res = tf.gather(tesnsor2[0],res) # tesnsor2[0] = [1,0,1,0]
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
result = sess.run(res)
print(result)
print(result.shape)
result = [ [0, 1, 0], [0, 0, 1] ]
显示出:indices的第一行作用于params上,随后indices的第二行作用于params
循环进行。