gather_nd的定义如下:
def gather_nd(params, indices, name=None)
功能:根据indeces描述的索引,在params中提取元素,重新组成一个tansor
- indices将切片定义为params的前N维度,N为indices.shape[-1].
- 通常要求indices.shape[-1] <= params.rank(可以通过np.ndim(params)获得,也就是params的维度。)
– 如果等号成立,则取具体元素
– 如果小于成立,是在沿params的indices.shape[-1]轴进行切片
–换一个角度理解,也就是说indices的最后一维的元素做具体的索引,比如最后一维的第一行为[0,1],这个意思是,先在params的第一维中找第0个, 在得到的 结果中找第一维(实际维params的第二维)的第1个,以此类推。 - 返回的维度: indices.shape[:-1] + params.shape[indices.shape[-1]:]
– 前面的indices.shape[:-1]代表索引后的指定形状
举例:
data=
[[[1 1 1]
[2 2 2]]
[[3 3 3]
[4 4 4]]
[[5 5 5]
[6 6 6]]]
data shape is (3, 2, 3)
data rank is 3
indices = np.array([[0, 1], [1, 0]])
indices shape is (2, 2)
indices.shape[-1]= 2 < data rank 3
[0, 1]的索引过程第一个元素是选取data的第一维(aixs=0)的第0个元素得到[[1, 1, 1], [2, 2, 2]],再选取第二维的第1个元素,得到[2, 2, 2]
[1, 0]最后得到[3, 3, 3]
最后得到结果[[2, 2, 2], [3, 3, 3]]
另外一种理解的方式:
最后的切片的结果是indices中表示索引的部分被提取到的值替换后得到的结构。
还是以上面的例子说明这个思路,
indices为 [[0, 1], [1, 0]],其中斜体部分为最后一维即具体索引的部分,
[0, 1]索引得到[2, 2, 2]
[1, 0]索引得到[3, 3, 3]
把索引的结果替换到indices中得到
[[2, 2, 2], [3, 3, 3]]
当索引indices为 [[[[1,1]]]]时,
先找出[1, 1]的索引结果为[4,4,4]
替换到上面结构中得到 [[[[4, 4, 4]]]]
这两个思路,第一个比较正式,严谨但不是那么容易理解和操作
第二个思路理解起来比较简单,容易操作
参考链接:
https://zhuanlan.zhihu.com/p/45673869
https://blog.csdn.net/G66565906/article/details/84949512