tf.gather_nd用法详解

gather_nd的定义如下:

def gather_nd(params, indices, name=None)

功能:根据indeces描述的索引,在params中提取元素,重新组成一个tansor

  • indices将切片定义为params的前N维度,N为indices.shape[-1].
  • 通常要求indices.shape[-1] <= params.rank(可以通过np.ndim(params)获得,也就是params的维度。)
    – 如果等号成立,则取具体元素
    – 如果小于成立,是在沿params的indices.shape[-1]轴进行切片
    –换一个角度理解,也就是说indices的最后一维的元素做具体的索引,比如最后一维的第一行为[0,1],这个意思是,先在params的第一维中找第0个, 在得到的 结果中找第一维(实际维params的第二维)的第1个,以此类推。
  • 返回的维度: indices.shape[:-1] + params.shape[indices.shape[-1]:]
    – 前面的indices.shape[:-1]代表索引后的指定形状

举例:
data=

[[[1 1 1]
  [2 2 2]]
 [[3 3 3]
  [4 4 4]]
 [[5 5 5]
  [6 6 6]]]

data shape is (3, 2, 3)
data rank is 3

indices = np.array([[0, 1], [1, 0]])
indices shape is (2, 2)
indices.shape[-1]= 2 < data rank 3

[0, 1]的索引过程第一个元素是选取data的第一维(aixs=0)的第0个元素得到[[1, 1, 1], [2, 2, 2]],再选取第二维的第1个元素,得到[2, 2, 2]
[1, 0]最后得到[3, 3, 3]
最后得到结果[[2, 2, 2], [3, 3, 3]]

另外一种理解的方式:
最后的切片的结果是indices中表示索引的部分被提取到的值替换后得到的结构。

还是以上面的例子说明这个思路,
indices为 [[0, 1], [1, 0]],其中斜体部分为最后一维即具体索引的部分,
[0, 1]索引得到[2, 2, 2]
[1, 0]索引得到[3, 3, 3]
把索引的结果替换到indices中得到
[[2, 2, 2], [3, 3, 3]]

当索引indices为 [[[[1,1]]]]时,
先找出[1, 1]的索引结果为[4,4,4]
替换到上面结构中得到 [[[[4, 4, 4]]]]

这两个思路,第一个比较正式,严谨但不是那么容易理解和操作
第二个思路理解起来比较简单,容易操作

参考链接:
https://zhuanlan.zhihu.com/p/45673869
https://blog.csdn.net/G66565906/article/details/84949512

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

飞天红猪侠001

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值