一、gather/gather_nd(已知元素的位置,从张量中提取该元素)
1、tf.gather()函数
tf.gather(
params, # 传入的tensor
indices, # 指定的索引
validate_indices=None, # 不重要
name=None, # 命名
axis=None, # 指定轴
batch_dims=0)
功能:就是抽取出params的第axis维度上在indices里面所有的index
- params是要查找的张量,indices是要查找值的索引(int32或int64),axis是查找轴,name是操作名。
- 如果indices是标量, o u t p u t [ a 0 , . . . , a n , b 0 , . . . , b n ] = p a r a m s [ a 0 , . . . a n , i n d i c e s , b 0 , . . . , b n ] output[a_0,...,a_n,b_0,...,b_n] = params[a_0,...a_n,indices,b_0,...,b_n] output[a0,...,an,b0,...,bn]=params[a0,...an,indices,b0,...,bn];
- 如果indices是向量, o u t p u t [ a 0 , . . . , a n , i , b 0 , . . . , b n ] = p a r a m s [ a 0 , . . . a n , i n d i c e s [ i ] , b 0 , . . . , b n ] output[a_0,...,a_n,i,b_0,...,b_n] = params[a_0,...a_n,indices[i],b_0,...,b_n] output[a0,...,an,i,b0,...,bn]=params[a0,...an,indices[i],b0,...,bn];
- 如果indices是高阶张量, o u t p u t [ a 0 , . . . , a n , i , . . . , j , b 0 , . . . , b n ] = p a r a m s [ a 0 , . . . a n , i n d i c e s [ i , . . . , j ] , b 0 , . . . , b n ] output[a_0,...,a_n,i,...,j,b_0,...,b_n] = params[a_0,...a_n,indices[i,...,j],b_0,...,b_n] output[a0,...,an,i,...,j,b0,...,bn]=params[a0,...an,indices[i,...,j],b0,...,bn]
需要注意的是indices里面最大值需要小等于params在指定的axis下ndim的长度。
如上图所示,params一共6个维度,indices为[2,1,3,4]被取了出来。
该函数返回值类型与params相同,具体值是从params中收集过来的,形状为: p a r a m s . s h a p e [ : a x i s ] + i n d i c e s . s h a p e + p a r a m s . s h a p e [ a x i s + 1 : ] params.shape[:axis]+indices.shape+params.shape[axis+1:] params.shape[:axis]+indices.shape+params.shape[axis+1:]
1.1 indices是标量
import numpy as np
import tensorflow as tf
c1 = tf.constant(np.random.randint(low=1, high=9, size=6))
print("c1 = ", c1)
print("-" * 100)
g1 = tf.gather(c1, indices=2) # 获取索引为 2 的值
print("g1 = tf.gather(c1, indices=2) = ", g1)
print("-" * 200)
打印结果:
c1 = tf.Tensor([7 2 8 8 5 4], shape=(6,), dtype=int32)
----------------------------------------------------------------------------------------------------
g1 = tf.gather(c1, indices=2) = tf.Tensor(8, shape=(), dtype=int32)
1.2 indices是向量
import tensorflow as tf
a = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_a = tf.Variable([2, 4, 6, 8])
b = tf.Variable([
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]
])
index_b = tf.Variable([0, 2])
print("a = \n", a)
print("-" * 100)
print("b = \n", b)
print("-" * 200)
g_a = tf.gather(a, indices=index_a)
print("g_a = tf.gather(a, indices=index_a) = ", g_a)
print("-" * 200)
g1 = tf.gather(b, indices=index_b, axis=0)
print("g1 = tf.gather(b, indices=index_b, axis=0) = ", g1)
print("-" * 100)
g2 = tf.gather(b, indices=index_b, axis=1)
print("g2 = tf.gather(b, indices=index_b, axis=1) = ", g2)
print("-" * 200)
打印结果:
a =
<tf.Variable 'Variable:0' shape=(10,) dtype=int32, numpy=array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])>
----------------------------------------------------------------------------------------------------
b =
<tf.Variable 'Variable:0' shape=(3, 5) dtype=int32, numpy=
array([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]])>
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
g_a = tf.gather(a, indices=index_a) = tf.Tensor([3 5 7 9], shape=(4,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
g1 = tf.gather(b, indices=index_b, axis=0) = tf.Tensor(
[[ 1 2 3 4 5]
[11 12 13 14 15]], shape=(2, 5), dtype=int32)
----------------------------------------------------------------------------------------------------
g2 = tf.gather(b, indices=index_b, axis=1) = tf.Tensor(
[[ 1 3]
[ 6 8]
[11 13]], shape=(3, 2), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Process finished with exit code 0
2、tf.gather_nd()函数
根据定义, 其主要功能是根据indices描述的索引,提取params上的元素, 重新构建一个tensor
tf.gather_nd(
params, # 被收集的张量
indices, # 索引张量。必须是以下类型之一:int32,int64。
name=None, # 操作的名称(可选)
batch_dims=0
)
indices是 K K K 阶张量,包含 K − 1 K-1 K−1 阶的索引值。它最后一阶是索引,最后一阶维度必须小于等于params的秩。
- indices最后一阶的维数等于params的秩时,我们得到params的某些元素;
- indices最后一阶的维数小于params的秩时,我们得到params的切片。
在一维数组中,元素的索引即该元素在数组中序号,通常序号从0开始标记。
如数组 ary=[1,2,3,4]:
- 元素2的索引 为 1, 元素的引用可表示为 [1];
- 元素3的索引为 2, 元素的引用可表示为 [2];
那么二维数组呢? 类似地,对于二维 ary=[ [1,2], [3,4] ],
- 元素 [1,2] 在一维中的索引为 [0],
- 元素 1 的索引 则表示为 [0,0],
- 元素 2 的索引 则表示为 [0,1],
因此 gather_nd 实现了根据指定的 参数 indices 来提取params 的元素重建出一个tensor,还是以上面的二维数组为例:
- [0,0] 表示 的是 1;
- [0,1] 表示的是 2;
当 i n d i c e s = [ [ 0 , 0 ] , [ 0 , 1 ] ] indices = [[0,0],[0,1]] indices=[[0,0],[0,1]] 时, 该函数的输出则为 [ 1 , 2 ] [1,2] [1,2],即 indices 中 表示索引的 部分 被提取到的值替换。
那么当indices 为[ [ [ [ [1,1] ] ] ] ] 时 函数输出是什么呢 ? 用元素 替换掉 表示索引的那一部分, 即可得到 [ [ [ [ 4 ] ] ] ]
例如: o u t p u t [ i 0 , . . . , i K − 2 ] = p a r a m s [ i n d i c e s [ i 0 , . . . i K − 2 ] ] \color{blue}{output[i_0,...,i_{K-2}]=params[indices[i_0,...i_{K-2}]]} output[i0,...,iK−2]=params[indices[i0,...iK−2]]。输出张量的形状由indices的 K − 1 K-1 K−1 阶和 params 索引到的形状拼接而成,形状为: i n d i c e s . s h a p e [ : − 1 ] + p a r a m s . s h a p e [ i n d i c e s . s h a p e [ − 1 ] : ] \color{blue}{indices.shape[:-1]+params.shape[indices.shape[-1]:]} indices.shape[:−1]+params.shape[indices.shape[−1]:]
tf.gather和tf.gather_nd都是从tensor中取出index标注的部分,不同之处在于,gather一般只使用一个index来标注,而gather_nd可以使用多个index。
import tensorflow as tf
params = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])
gather = tf.constant([0, 2])
gather_nd = tf.constant([[0, 0], [1, 1]])
gather_result = tf.gather(params=params, indices=gather)
gather_nd_result = tf.gather_nd(params=params, indices=gather_nd)
print("gather_result = ", gather_result)
print("-" * 200)
print("gather_nd_result = ", gather_nd_result)
打印结果:
gather_result = tf.Tensor(
[[b'a' b'b']
[b'e' b'f']], shape=(2, 2), dtype=string)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
gather_nd_result = tf.Tensor([b'a' b'd'], shape=(2,), dtype=string)
如何直观理解gather_nd的indices呢?
- 在上例中,直观的理解就是,gather_nd取出params中位于[0,0]和[1,1]处的tensor,放入index中对应的位置。
- 换句话说,除去tensor维之外,返回值的形状和indices相同,值由indices标注。
如果理解了这一点,就可以用gather_nd实现gather:
import tensorflow as tf
params = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])
gather_nd = tf.constant([[0], [2]])
gather_nd_result = tf.gather_nd(params=params, indices=gather_nd)
print("gather_nd_result = ", gather_nd_result)
打印结果:
gather_nd_result = tf.Tensor(
[[b'a' b'b']
[b'e' b'f']], shape=(2, 2), dtype=string)
2.1 案例01
import tensorflow as tf
data = tf.constant([[1, 2], [3, 4], [5, 6]])
indices = tf.constant([[1], [0], [1]])
print('data =\n', data)
print("-" * 50)
print('indices =\n', indices)
print("-" * 100)
res = tf.gather_nd(data, indices)
print('res =\n', res)
打印结果:
data =
tf.Tensor(
[[1 2]
[3 4]
[5 6]], shape=(3, 2), dtype=int32)
--------------------------------------------------
indices =
tf.Tensor(
[[1]
[0]
[1]], shape=(3, 1), dtype=int32)
----------------------------------------------------------------------------------------------------
res =
tf.Tensor(
[[3 4]
[1 2]
[3 4]], shape=(3, 2), dtype=int32)
Process finished with exit code 0
2.2 案例02
import tensorflow as tf
data = tf.constant([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
indices = tf.constant([[1, 0], [0, 2], [1, 2]])
print('data =\n', data)
print("-" * 50)
print('indices =\n', indices)
print("-" * 100)
res = tf.gather_nd(data, indices)
print('res =\n', res)
打印结果:
data =
tf.Tensor(
[[1 2 3]
[3 4 5]
[5 6 7]], shape=(3, 3), dtype=int32)
--------------------------------------------------
indices =
tf.Tensor(
[[1 0]
[0 2]
[1 2]], shape=(3, 2), dtype=int32)
----------------------------------------------------------------------------------------------------
res =
tf.Tensor([3 3 5], shape=(3,), dtype=int32)
Process finished with exit code 0
三、tf.scatter_nd()函数:已知赋值位置,向0张量中赋值
根据indices索引位置将updates中的元素 散布 到新的(初始为零)张量shape中去。
- 根据索引对给定shape的零张量中的单个值或切片应用稀疏updates来创建新的张量。
- scatter_nd运算符是 tf.gather_nd 运算符的反函数,tf.gather_nd 运算符是从给定的张量中提取值或切片。
scatter_nd(indices,updates,shape,name=None)
- indices:一个Tensor;必须是以下类型之一:int32,int64;指数张量。
- updates:一个Tensor;分散到输出的更新。
- shape:一个Tensor;必须与indices具有相同的类型;1-d;得到的张量的形状。
- name:操作的名称(可选)。
警告:更新应用的顺序是非确定性的,所以如果indices包含重复项的话,则输出将是不确定的。
indices是一个整数张量,其中含有索引形成一个新的形状shape张量。indices的最后的维度可以是shape的最多的秩:
indices.shape[-1] <= shape.rank
indices的最后一个维度对应于沿着shape的indices.shape[-1]维度的元素的索引(if indices.shape[-1] = shape.rank)或切片(if indices.shape[-1] < shape.rank)的索引。updates是一个具有如下形状的张量:
indices.shape[:-1] + shape[indices.shape[-1]:]
1、案例01
最简单的分散形式是通过索引将单个元素插入到张量中。例如,假设我们想要在8个元素的1级张量中插入4个分散的元素。
import tensorflow as tf
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
output= tf.scatter_nd(indices, updates, shape)
print("output = ", output)
打印结果:
output = tf.Tensor([ 0 11 0 10 9 0 0 12], shape=(8,), dtype=int32)
2、案例02
我们也可以一次插入一个更高阶张量的整个片。例如,如果我们想要在具有两个新值的矩阵的第三维张量中插入两个切片。
在Python中,这个分散操作看起来像这样:
import tensorflow as tf
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
print("shape_zero = ", tf.zeros([4, 4, 4]))
print("-" * 200)
output = tf.scatter_nd(indices, updates, shape)
print("output = ", output)
打印结果:
shape_zero = tf.Tensor(
[[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]]], shape=(4, 4, 4), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
output = tf.Tensor(
[[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]
[[5 5 5 5]
[6 6 6 6]
[7 7 7 7]
[8 8 8 8]]
[[0 0 0 0]
[0 0 0 0]
[0 0 0 0]
[0 0 0 0]]], shape=(4, 4, 4), dtype=int32)
Process finished with exit code 0
参考资料:
[tensorflow] tf.gather使用方法
tf.gather()函数详解
tf.gather( )的用法
tf.gather_nd和tf.gather的区别与联系
TensorFlow中gather, gather_nd, scatter, scatter_nd用法浅析
TensorFlow学习(三):tf.scatter_nd函数
Tensorflow (一): scatter_nd 与 gather_nd
Python tensorflow.scatter_nd方法代码示例
TensorFlow中tf.gather()函数的使用讲解
tf.gather_nd 用法
深度理解tf.gather和tf.gather_nd的用法
Python tensorflow.gather_nd()用法及代码示例