tf.gather_nd的函数原型是:
def gather_nd(params, indices, name=None)
根据定义, 其主要功能是根据indices描述的索引,提取params上的元素, 重新构建一个tensor
在谈论该函数之前,我们先来看一下 索引的概念,
在一维数组中,元素的索引即该元素在数组中序号,通常序号从0开始标记
如数组 ary=[1,2,3,4];
元素2的索引 为 1, 元素的引用可表示为 [1]
元素3的索引为 2, 元素的引用可表示为 [2]
那么二维数组呢? 类似地
对于二维 ary=[ [1,2], [3,4] ]
元素 [1,2] 在一维中的索引为 [0], 元素 1 的索引 则表示为 [0,0], 元素 2 的索引 则表示为 [0,1],
因此 gather_nd 实现了根据指定的 参数 indices 来提取params 的元素重建出一个tensor,
还是以上面的二维数组为例
[0,0] 表示 的是 1,
[0,1] 表示的是 2
当indices 为 [[0,0],[0,1]] 时, 该函数的输出则为 [1,2]
即 indices 中 表示索引的 部分 被提取到的值替换
那么当indices 为[ [ [ [ [1,1] ] ] ] ] 时 函数输出是什么呢 ? 用元素 替换掉 表示索引的那一部分, 即可得到 [ [ [ [ 4 ] ] ] ]