tf.gather和tf.gather_nd的使用和区别

# coding:utf-8
"""

@author: liu
@File: tf_gather_nd.py
@CreateTime: 2021/6/17 
"""

"""
测试tf.gather和tf.gather_nd的区别和联系


"""

import numpy as np
import tensorflow as tf
print("tensorflow version: ", tf.__version__)

data = np.arange(18).reshape((3, 2, 3))
print(data)
print(np.ndim(data))


indices = [0, 1]
print(tf.gather(data, indices, axis=0))
print(tf.gather(data, indices, axis=1))  # 设置indices = [1]
print(tf.gather(data, indices, axis=2))


"""
测试tf.gather_nd
"""

print("--------------------------------------------------------------------------------------------")
indices = np.array([[0, 1, 2], [1, 0, 1]])
print(data)
print("===")
print(tf.gather_nd(data, indices))  # tf.Tensor([5 7], shape=(2,), dtype=int32)

"""

定义讲解:
作用:将params索引为indices指定形状的切片数组中(indices代表索引后的数组形状)
indices将切片定义为params的前N个维度,其中N = indices.shape [-1]
通常要求indices.shape[-1] <= params.rank(可以用np.ndim(params)查看)
如果等号成立是在索引具体元素
如果等号不成立是在沿params的indices.shape[-1]轴进行切片
返回维度: indices.shape[:-1] + params.shape[indices.shape[-1]:]
前面的indices.shape[:-1]代表索引后的指定形状

主要讲解indices在params的讲解过程:
其中indices的维度不能超过params的维度np.ndim(params),其中indices的每个元素代表着params的前n个维度对应的切片索引
比如上述的indices[0]= [0, 1, 2], 维度为3, 因此需要在params的前3个维度进行索引,那么在params中进行索引的过程时:data[0][1][2]=5, 获取第一个索引对应的值;同理indices[1] = [1, 0, 1], 获取得到的索引值data[1][0][1]=7
所以tf.gather_nd(data, indices)=[5, 7]



"""

indices = [[0, 1], [1, 0]]
"""
tf.Tensor(
[[3 4 5]
 [6 7 8]], shape=(2, 3), dtype=int32)

索引索引:其中n=indices.shape[-1] 代表着索引data的前n个维度,indices中的每个元素代表着data前n个索引对应的索引值
 比如indices[0]= [0, 1],维度为2, 代表着索引data的前2个维度, 每个维度依次对应的值为[0, 1], 因此值为data[0][1]= [3, 4, 5]
 同理可以得indices[1] = [1, 0], 获取得到的索引值为data[1][0] = [6, 7, 8]

最终结果是:[[3, 4, 5], [6, 7, 8 ]]
 
"""
print(tf.gather_nd(data, indices=indices))








 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值