tf.gather_nd和tf.gather的区别与联系
tf.gather_nd和tf.gather都是选择其中需要的部分组成一个新的矩阵。
其中tf.gather(params,indices,axis=0.name=None)是选择相同维度的元素组成新的矩阵,params是待处理tensor,indices是一个1-D tensor表示对应轴上的索引,axis是指定维度,name是操作名字。
tf.gather_nd(params,indices,可以选择不同维度的元素,其中最重要的两个参数是params和indices,params是待处理矩阵,indices则是params对应的索引,这个就可能不是1-D tensor了,是针对整个tensor的索引。
tf.gather_nd这个概念看官网总感觉不能透彻理解,而且官网没有给出3-D tensor以上的算法,经过思考和代码实践,我发现无论维度多少,可以简单理解为将索引部分替换为对应的元素。
拿一个官网例子,如下:
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
按照之前说的替换,就是拿出params中索引[0,0],[1,1]对应的元素组成一个新矩阵,就是’a’,’d’组成新矩阵就是[‘a’,’d’]
除了索引到具体元素的,还可以选向量或者更高维的tensor,再拿一个官网例子
indices = [[1]]
params