tf.gather
Gather slices from params along axis, at the positions given by indices.
tf.gather(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None)
Example:
import tensorflow as tf

# data: [classes, students, subjects]
data = tf.ones([4, 35, 8])
print(data.shape) # TensorShape([4, 35, 8])
# sample several classes: keep classes 2 and 3 along axis 0
data = tf.gather(data, axis=0, indices=[2,3])
print(data.shape) # TensorShape([2, 35, 8])
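To make the role of axis concrete, here is a minimal pure-Python sketch of the same selection on nested lists. The helpers gather and shape are hypothetical stand-ins written for illustration, not TensorFlow functions, and the sketch assumes indices is a flat list of integers.

```python
def gather(params, indices, axis=0):
    """Pick the entries at `indices` along `axis` of a nested list (illustrative only)."""
    if axis == 0:
        # Select whole slices along the leading dimension.
        return [params[i] for i in indices]
    # Otherwise descend one level and gather inside each sub-list.
    return [gather(sub, indices, axis - 1) for sub in params]


def shape(nested):
    """Return the dimensions of a regular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return dims


# scores: [classes, students, subjects], same layout as the example above
scores = [[[1.0] * 8 for _ in range(35)] for _ in range(4)]

by_class = gather(scores, [2, 3], axis=0)       # keep classes 2 and 3
by_student = gather(scores, [0, 5, 9], axis=1)  # keep students 0, 5, 9 in every class

print(shape(by_class))    # [2, 35, 8]
print(shape(by_student))  # [4, 3, 8]
```

Only the dimension named by axis changes size: it becomes the number of indices, and every other dimension is left untouched.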
tf.gather_nd
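Gather slices from params into a Tensor with shape specified by indices. Each innermost vector of indices is a (possibly partial) coordinate into the leading dimensions of params.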
tf.gather_nd(params, indices, batch_dims=0, name=None)
Example:
# data: [classes, students, subjects]
data = tf.ones([4,35,8])
print(data.shape) # TensorShape([4, 35, 8])
# sample several (class, student) pairs
# for instance: [class1_student1, class2_student2, class3_student3, class4_student4]
data = tf.gather_nd(data, [[0,0],[1,1],[2,2],[3,3]])
print(data.shape) # TensorShape([4, 8])
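The coordinate lookup can be sketched in pure Python as follows. The helper gather_nd below is a hypothetical nested-list stand-in, not the TensorFlow op, and it assumes indices is a flat list of coordinates (a 2-D indices argument) with batch_dims left at 0.

```python
def gather_nd(params, indices):
    """Look up each coordinate in `indices` inside a nested list (illustrative only)."""
    picked = []
    for coordinate in indices:
        item = params
        # Walk one level deeper for every component of the coordinate.
        for i in coordinate:
            item = item[i]
        picked.append(item)
    return picked


# scores[c][s][j] records its own position so the result is easy to inspect
scores = [[[(c, s, j) for j in range(8)] for s in range(35)] for c in range(4)]

# One (class, student) pair per row; each pair selects a row of 8 subjects
rows = gather_nd(scores, [[0, 0], [1, 1], [2, 2], [3, 3]])

print(len(rows), len(rows[0]))  # 4 8
print(rows[2][0])               # (2, 2, 0)
```

Each coordinate has two components, so it consumes the first two dimensions of the data and leaves the last one (8 subjects) intact. That is why four coordinates produce a result of shape [4, 8].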