# coding:utf-8
"""
@author: liu
@File: tf_gather_nd.py
@CreateTime: 2021/6/17
"""
"""
测试tf.gather和tf.gather_nd的区别和联系
"""
import numpy as np
import tensorflow as tf
print("tensorflow version: ", tf.__version__)
data = np.arange(18).reshape((3, 2, 3))
print(data)
print(np.ndim(data))
indices = [0, 1]
print(tf.gather(data, indices, axis=0))
print(tf.gather(data, indices, axis=1)) # 设置indices = [1]
print(tf.gather(data, indices, axis=2))
"""
测试tf.gather_nd
"""
print("--------------------------------------------------------------------------------------------")
indices = np.array([[0, 1, 2], [1, 0, 1]])
print(data)
print("===")
print(tf.gather_nd(data, indices)) # tf.Tensor([5 7], shape=(2,), dtype=int32)
"""
定义讲解:
作用:将params索引为indices指定形状的切片数组中(indices代表索引后的数组形状)
indices将切片定义为params的前N个维度,其中N = indices.shape [-1]
通常要求indices.shape[-1] <= params.rank(可以用np.ndim(params)查看)
如果等号成立是在索引具体元素
如果等号不成立是在沿params的indices.shape[-1]轴进行切片
返回维度: indices.shape[:-1] + params.shape[indices.shape[-1]:]
前面的indices.shape[:-1]代表索引后的指定形状
主要讲解indices在params的讲解过程:
其中indices的维度不能超过params的维度np.ndim(params),其中indices的每个元素代表着params的前n个维度对应的切片索引
比如上述的indices[0]= [0, 1, 2], 维度为3, 因此需要在params的前3个维度进行索引,那么在params中进行索引的过程时:data[0][1][2]=5, 获取第一个索引对应的值;同理indices[1] = [1, 0, 1], 获取得到的索引值data[1][0][1]=7
所以tf.gather_nd(data, indices)=[5, 7]
"""
indices = [[0, 1], [1, 0]]
"""
tf.Tensor(
[[3 4 5]
[6 7 8]], shape=(2, 3), dtype=int32)
索引索引:其中n=indices.shape[-1] 代表着索引data的前n个维度,indices中的每个元素代表着data前n个索引对应的索引值
比如indices[0]= [0, 1],维度为2, 代表着索引data的前2个维度, 每个维度依次对应的值为[0, 1], 因此值为data[0][1]= [3, 4, 5]
同理可以得indices[1] = [1, 0], 获取得到的索引值为data[1][0] = [6, 7, 8]
最终结果是:[[3, 4, 5], [6, 7, 8 ]]
"""
print(tf.gather_nd(data, indices=indices))