TensorFlow2-高阶操作(八):gather/gather_nd(已知元素的位置,从张量中提取该元素)、scatter_nd/scatter_nd_update(已知赋值位置,向0张量中赋值)

一、gather/gather_nd(已知元素的位置,从张量中提取该元素)

1、tf.gather()函数

tf.gather(
		params,					# 传入的tensor
        indices,				# 指定的索引
        validate_indices=None,	# 不重要
        name=None,				# 命名
        axis=None,				# 指定轴
        batch_dims=0)

功能:就是抽取出params的第axis维度上在indices里面所有的index

  • params是要查找的张量,indices是要查找值的索引(int32或int64),axis是查找轴,name是操作名。
  • 如果indices是标量, o u t p u t [ a 0 , . . . , a n , b 0 , . . . , b n ] = p a r a m s [ a 0 , . . . a n , i n d i c e s , b 0 , . . . , b n ] output[a_0,...,a_n,b_0,...,b_n] = params[a_0,...a_n,indices,b_0,...,b_n] output[a0,...,an,b0,...,bn]=params[a0,...an,indices,b0,...,bn]
  • 如果indices是向量, o u t p u t [ a 0 , . . . , a n , i , b 0 , . . . , b n ] = p a r a m s [ a 0 , . . . a n , i n d i c e s [ i ] , b 0 , . . . , b n ] output[a_0,...,a_n,i,b_0,...,b_n] = params[a_0,...a_n,indices[i],b_0,...,b_n] output[a0,...,an,i,b0,...,bn]=params[a0,...an,indices[i],b0,...,bn]
  • 如果indices是高阶张量, o u t p u t [ a 0 , . . . , a n , i , . . . , j , b 0 , . . . , b n ] = p a r a m s [ a 0 , . . . a n , i n d i c e s [ i , . . . , j ] , b 0 , . . . , b n ] output[a_0,...,a_n,i,...,j,b_0,...,b_n] = params[a_0,...a_n,indices[i,...,j],b_0,...,b_n] output[a0,...,an,i,...,j,b0,...,bn]=params[a0,...an,indices[i,...,j],b0,...,bn]

需要注意的是indices里面最大值需要小等于params在指定的axis下ndim的长度。
在这里插入图片描述
如上图所示,params一共6个维度,indices为[2,1,3,4]被取了出来。

该函数返回值类型与params相同,具体值是从params中收集过来的,形状为: p a r a m s . s h a p e [ : a x i s ] + i n d i c e s . s h a p e + p a r a m s . s h a p e [ a x i s + 1 : ] params.shape[:axis]+indices.shape+params.shape[axis+1:] params.shape[:axis]+indices.shape+params.shape[axis+1:]

1.1 indices是标量

import numpy as np
import tensorflow as tf

c1 = tf.constant(np.random.randint(low=1, high=9, size=6))
print("c1 = ", c1)
print("-" * 100)

g1 = tf.gather(c1, indices=2)  # 获取索引为 2 的值
print("g1 = tf.gather(c1, indices=2) = ", g1)
print("-" * 200)

打印结果:

c1 =  tf.Tensor([7 2 8 8 5 4], shape=(6,), dtype=int32)
----------------------------------------------------------------------------------------------------
g1 = tf.gather(c1, indices=2) =  tf.Tensor(8, shape=(), dtype=int32)

1.2 indices是向量

import tensorflow as tf

a = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_a = tf.Variable([2, 4, 6, 8])

b = tf.Variable([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15]
])
index_b = tf.Variable([0, 2])

print("a = \n", a)
print("-" * 100)
print("b = \n", b)
print("-" * 200)

g_a = tf.gather(a, indices=index_a)
print("g_a = tf.gather(a, indices=index_a) = ", g_a)
print("-" * 200)

g1 = tf.gather(b, indices=index_b, axis=0)
print("g1 = tf.gather(b, indices=index_b, axis=0) = ", g1)
print("-" * 100)

g2 = tf.gather(b, indices=index_b, axis=1)
print("g2 = tf.gather(b, indices=index_b, axis=1) = ", g2)
print("-" * 200)

打印结果:

a = 
<tf.Variable 'Variable:0' shape=(10,) dtype=int32, numpy=array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])>
----------------------------------------------------------------------------------------------------
b = 
 <tf.Variable 'Variable:0' shape=(3, 5) dtype=int32, numpy=
array([[ 1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10],
       [11, 12, 13, 14, 15]])>
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
g_a = tf.gather(a, indices=index_a) =  tf.Tensor([3 5 7 9], shape=(4,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
g1 = tf.gather(b, indices=index_b, axis=0) =  tf.Tensor(
[[ 1  2  3  4  5]
 [11 12 13 14 15]], shape=(2, 5), dtype=int32)
----------------------------------------------------------------------------------------------------
g2 = tf.gather(b, indices=index_b, axis=1) =  tf.Tensor(
[[ 1  3]
 [ 6  8]
 [11 13]], shape=(3, 2), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Process finished with exit code 0

2、tf.gather_nd()函数

根据定义, 其主要功能是根据indices描述的索引,提取params上的元素, 重新构建一个tensor

tf.gather_nd(
			params,  			# 被收集的张量
			indices, 			# 索引张量。必须是以下类型之一:int32,int64。
			name=None, 	# 操作的名称(可选)
			batch_dims=0
			)

indices是 K K K 阶张量,包含 K − 1 K-1 K1 阶的索引值。它最后一阶是索引,最后一阶维度必须小于等于params的秩。

  • indices最后一阶的维数等于params的秩时,我们得到params的某些元素;
  • indices最后一阶的维数小于params的秩时,我们得到params的切片。

在一维数组中,元素的索引即该元素在数组中序号,通常序号从0开始标记。

如数组 ary=[1,2,3,4]:

  • 元素2的索引 为 1, 元素的引用可表示为 [1];
  • 元素3的索引为 2, 元素的引用可表示为 [2];

那么二维数组呢? 类似地,对于二维 ary=[ [1,2], [3,4] ],

  • 元素 [1,2] 在一维中的索引为 [0],
  • 元素 1 的索引 则表示为 [0,0],
  • 元素 2 的索引 则表示为 [0,1],

因此 gather_nd 实现了根据指定的 参数 indices 来提取params 的元素重建出一个tensor,还是以上面的二维数组为例:

  • [0,0] 表示 的是 1;
  • [0,1] 表示的是 2;

i n d i c e s = [ [ 0 , 0 ] , [ 0 , 1 ] ] indices = [[0,0],[0,1]] indices=[[0,0],[0,1]] 时, 该函数的输出则为 [ 1 , 2 ] [1,2] [1,2],即 indices 中 表示索引的 部分 被提取到的值替换。

那么当indices 为[ [ [ [ [1,1] ] ] ] ] 时 函数输出是什么呢 ? 用元素 替换掉 表示索引的那一部分, 即可得到 [ [ [ [ 4 ] ] ] ]

例如: o u t p u t [ i 0 , . . . , i K − 2 ] = p a r a m s [ i n d i c e s [ i 0 , . . . i K − 2 ] ] \color{blue}{output[i_0,...,i_{K-2}]=params[indices[i_0,...i_{K-2}]]} output[i0,...,iK2]=params[indices[i0,...iK2]]。输出张量的形状由indices的 K − 1 K-1 K1 阶和 params 索引到的形状拼接而成,形状为: i n d i c e s . s h a p e [ : − 1 ] + p a r a m s . s h a p e [ i n d i c e s . s h a p e [ − 1 ] : ] \color{blue}{indices.shape[:-1]+params.shape[indices.shape[-1]:]} indices.shape[:1]+params.shape[indices.shape[1]:]

tf.gather和tf.gather_nd都是从tensor中取出index标注的部分,不同之处在于,gather一般只使用一个index来标注,而gather_nd可以使用多个index。

import tensorflow as tf

params = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])

gather = tf.constant([0, 2])
gather_nd = tf.constant([[0, 0], [1, 1]])

gather_result = tf.gather(params=params, indices=gather)
gather_nd_result = tf.gather_nd(params=params, indices=gather_nd)

print("gather_result = ", gather_result)
print("-" * 200)
print("gather_nd_result = ", gather_nd_result)

打印结果:

gather_result =  tf.Tensor(
[[b'a' b'b']
 [b'e' b'f']], shape=(2, 2), dtype=string)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
gather_nd_result =  tf.Tensor([b'a' b'd'], shape=(2,), dtype=string)

如何直观理解gather_nd的indices呢?

  • 在上例中,直观的理解就是,gather_nd取出params中位于[0,0]和[1,1]处的tensor,放入index中对应的位置。
  • 换句话说,除去tensor维之外,返回值的形状和indices相同,值由indices标注

如果理解了这一点,就可以用gather_nd实现gather:

import tensorflow as tf

params = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])

gather_nd = tf.constant([[0], [2]])
gather_nd_result = tf.gather_nd(params=params, indices=gather_nd)

print("gather_nd_result = ", gather_nd_result)

打印结果:

gather_nd_result =  tf.Tensor(
[[b'a' b'b']
 [b'e' b'f']], shape=(2, 2), dtype=string)

2.1 案例01

import tensorflow as tf

data = tf.constant([[1, 2], [3, 4], [5, 6]])
indices = tf.constant([[1], [0], [1]])

print('data =\n', data)
print("-" * 50)
print('indices =\n', indices)
print("-" * 100)

res = tf.gather_nd(data, indices)

print('res =\n', res)

打印结果:

data =
 tf.Tensor(
[[1 2]
 [3 4]
 [5 6]], shape=(3, 2), dtype=int32)
--------------------------------------------------
indices =
 tf.Tensor(
[[1]
 [0]
 [1]], shape=(3, 1), dtype=int32)
----------------------------------------------------------------------------------------------------
res =
 tf.Tensor(
[[3 4]
 [1 2]
 [3 4]], shape=(3, 2), dtype=int32)

Process finished with exit code 0

2.2 案例02

import tensorflow as tf

data = tf.constant([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
indices = tf.constant([[1, 0], [0, 2], [1, 2]])

print('data =\n', data)
print("-" * 50)
print('indices =\n', indices)
print("-" * 100)

res = tf.gather_nd(data, indices)

print('res =\n', res)

打印结果:

data =
 tf.Tensor(
[[1 2 3]
 [3 4 5]
 [5 6 7]], shape=(3, 3), dtype=int32)
--------------------------------------------------
indices =
 tf.Tensor(
[[1 0]
 [0 2]
 [1 2]], shape=(3, 2), dtype=int32)
----------------------------------------------------------------------------------------------------
res =
 tf.Tensor([3 3 5], shape=(3,), dtype=int32)

Process finished with exit code 0

三、tf.scatter_nd()函数:已知赋值位置,向0张量中赋值

根据indices索引位置将updates中的元素 散布 到新的(初始为零)张量shape中去。

  • 根据索引对给定shape的零张量中的单个值或切片应用稀疏updates来创建新的张量。
  • scatter_nd运算符是 tf.gather_nd 运算符的反函数,tf.gather_nd 运算符是从给定的张量中提取值或切片。

scatter_nd(indices,updates,shape,name=None)

  • indices:一个Tensor;必须是以下类型之一:int32,int64;指数张量。
  • updates:一个Tensor;分散到输出的更新。
  • shape:一个Tensor;必须与indices具有相同的类型;1-d;得到的张量的形状。
  • name:操作的名称(可选)。

警告:更新应用的顺序是非确定性的,所以如果indices包含重复项的话,则输出将是不确定的。

indices是一个整数张量,其中含有索引形成一个新的形状shape张量。indices的最后的维度可以是shape的最多的秩:

indices.shape[-1] <= shape.rank

indices的最后一个维度对应于沿着shape的indices.shape[-1]维度的元素的索引(if indices.shape[-1] = shape.rank)或切片(if indices.shape[-1] < shape.rank)的索引。updates是一个具有如下形状的张量:

indices.shape[:-1] + shape[indices.shape[-1]:]

1、案例01

最简单的分散形式是通过索引将单个元素插入到张量中。例如,假设我们想要在8个元素的1级张量中插入4个分散的元素。
在这里插入图片描述

import tensorflow as tf

indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])

output= tf.scatter_nd(indices, updates, shape)
print("output = ", output)

打印结果:

output =  tf.Tensor([ 0 11  0 10  9  0  0 12], shape=(8,), dtype=int32)

2、案例02

我们也可以一次插入一个更高阶张量的整个片。例如,如果我们想要在具有两个新值的矩阵的第三维张量中插入两个切片。
在这里插入图片描述
在Python中,这个分散操作看起来像这样:

import tensorflow as tf

indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
print("shape_zero = ", tf.zeros([4, 4, 4]))
print("-" * 200)

output = tf.scatter_nd(indices, updates, shape)
print("output = ", output)

打印结果:

shape_zero =  tf.Tensor(
[[[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]], shape=(4, 4, 4), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
output =  tf.Tensor(
[[[5 5 5 5]
  [6 6 6 6]
  [7 7 7 7]
  [8 8 8 8]]

 [[0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]]

 [[5 5 5 5]
  [6 6 6 6]
  [7 7 7 7]
  [8 8 8 8]]

 [[0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]
  [0 0 0 0]]], shape=(4, 4, 4), dtype=int32)

Process finished with exit code 0



参考资料:
[tensorflow] tf.gather使用方法
tf.gather()函数详解
tf.gather( )的用法
tf.gather_nd和tf.gather的区别与联系
TensorFlow中gather, gather_nd, scatter, scatter_nd用法浅析
TensorFlow学习(三):tf.scatter_nd函数
Tensorflow (一): scatter_nd 与 gather_nd
Python tensorflow.scatter_nd方法代码示例
TensorFlow中tf.gather()函数的使用讲解
tf.gather_nd 用法
深度理解tf.gather和tf.gather_nd的用法
Python tensorflow.gather_nd()用法及代码示例

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值