直接上代码:利用 tf.gather_nd 函数,按多维索引从张量中取出对应的元素:
eg1: 常数索引
import tensorflow as tf
# Example 1: look up a single element of a matrix with a literal index.
sess = tf.Session()
a = tf.constant(
    [
        [0, 1, 2],
        [3, 4, 5],
    ]
)  # shape: (2, 3)
# The index [0, 1] addresses row 0, column 1 of `a`.
result = tf.gather_nd(params=a, indices=[0, 1])
sess.run(result)  # evaluates to 1
eg2: 变量索引
import tensorflow as tf
# Example 2: look up a single element of a matrix using an index that is
# stored in a tf.Variable instead of a Python literal.
sess = tf.Session()
a = tf.constant([[0, 1, 2], [3, 4, 5]])  # shape: (2, 3)
b = tf.Variable([0, 1], dtype=tf.int32)  # index: row 0, column 1
# A variable has to be initialized in the session before an op can read it.
# The initializer must come from `b`; `id` is the Python builtin function
# and has no `initializer` attribute, so `id.initializer` raises
# AttributeError.
sess.run(b.initializer)
result = tf.gather_nd(a, b)
sess.run(result)  # evaluates to 1