tf.gather,tf.range()的详解

在讲解这个之前,我们首先讲一下tf.range(),因为这两个一般都是在一起用的

tf.range()

其和python中的range()的用法基本一样,只不过这里返回的是一个1-D的tensor
tf.range(limit, delta=1, dtype=None, name=‘range’)
tf.range(start, limit, delta=1, dtype=None, name=‘range’)

'''
Args:
	start: A 0-D Tensor (scalar). Acts as first entry in the range if limit is not None; otherwise, acts as range limit and first entry defaults to 0.
	limit: A 0-D Tensor (scalar). Upper limit of sequence, exclusive. If None, defaults to the value of start while the first entry of the range defaults to 0.
	delta: A 0-D Tensor (scalar). Number that increments start. Defaults to 1.
	dtype: The type of the elements of the resulting tensor.
	name: A name for the operation. Defaults to "range".
Returns:
	An 1-D Tensor of type dtype.
'''

tf.gather

该接口的作用:就是抽取出params的第axis维度上在indices里面所有的index(看后面的例子,就会懂)

tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

'''
Args:
	params: A Tensor. The tensor from which to gather values. Must be at least rank axis + 1.
	indices: A Tensor. Must be one of the following types: int32, int64. Index tensor. Must be in range [0, params.shape[axis]).
	axis: A Tensor. Must be one of the following types: int32, int64. The axis in params to gather indices from. Defaults to the first dimension. Supports negative indexes.
	name: A name for the operation (optional).
Returns:
	A Tensor. Has the same type as params.
'''
说明
参数
  • params: A Tensor.
  • indices: A Tensor. types必须是: int32, int64. 里面的每一个元素大小必须在 [0, params.shape[axis])范围内.
  • axis: 维度。沿着params的哪一个维度进行抽取indices
返回

返回的是一个tensor

帮助理解图

在这里插入图片描述

例子1
代码
import tensorflow as tf

Params = tf.range(0,10)*10
a = tf.gather(Params,[0,5,9])

with tf.Session() as sess:
    print("Params:  \n",sess.run(Params))
    print("抽取的结果: \n",sess.run(a))


输出
Params:  
        [ 0 10 20 30 40 50 60 70 80 90]
抽取的结果: 
      [ 0 50 90]
例子2
代码
import tensorflow as tf
Params=tf.Variable(tf.random_normal([2,3,4]))
indicxs_0=[0,1]
indicxs_1=[0,2]
indicxs_2=[2,3]

gather_0=tf.gather(params=Params,indices=indicxs_0,axis=0)
gather_1=tf.gather(params=Params,indices=indicxs_1,axis=1)
gather_2=tf.gather(params=Params,indices=indicxs_2,axis=2)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("Params :\n      ",sess.run(Params))
    print("沿着第O维度抽取第0,1个:    \n",sess.run(gather_0))
    print("沿着第1维度抽取第0,3个:    \n",sess.run(gather_1))
    print("沿着第2维度抽取第3,4个:    \n",sess.run(gather_2))

输出
Params :
       [
       [[ 0.78150964  2.09648061  2.37558031  1.20743346]
		[-1.12413085 -0.66349769  1.15486336 -1.17151475]
		 [ 0.0476133  -0.09292984 -0.29620713  0.70557141]]

		 [[-1.34968698 -0.2931003  -1.94950449 -0.27036974]
		  [ 0.27591622 -0.19094539 -0.56113148  0.55863774]
		  [-0.48273012 -0.7819376   0.3261987  -0.97833097]]
		  ]
沿着第O维度抽取第0,1个:    
 [[[ 0.78150964  2.09648061  2.37558031  1.20743346]
  [-1.12413085 -0.66349769  1.15486336 -1.17151475]
  [ 0.0476133  -0.09292984 -0.29620713  0.70557141]]

 [[-1.34968698 -0.2931003  -1.94950449 -0.27036974]
  [ 0.27591622 -0.19094539 -0.56113148  0.55863774]
  [-0.48273012 -0.7819376   0.3261987  -0.97833097]]]
沿着第1维度抽取第0,3个:    
 [[[ 0.78150964  2.09648061  2.37558031  1.20743346]
  [ 0.0476133  -0.09292984 -0.29620713  0.70557141]]

 [[-1.34968698 -0.2931003  -1.94950449 -0.27036974]
  [-0.48273012 -0.7819376   0.3261987  -0.97833097]]]
沿着第2维度抽取第3,4个:    
 [[[ 2.37558031  1.20743346]
  [ 1.15486336 -1.17151475]
  [-0.29620713  0.70557141]]

 [[-1.94950449 -0.27036974]
  [-0.56113148  0.55863774]
  [ 0.3261987  -0.97833097]]]
### TensorFlow `tf.gather` 的使用方法 `tf.gather` 是 TensorFlow 提供的一个用于从张量中提取子集的功能强大的工具。它允许通过指定索引来获取张量的部分内容,类似于 NumPy 的高级索引功能。 #### 参数说明 以下是 `tf.gather` 的主要参数及其作用: - **params**: 输入的张量,从中提取数据。 - **indices**: 表示要提取的数据的位置索引,可以是一维或多维数组。 - **axis**: 指定沿哪个轴进行聚集操作,默认为 0。 - **batch_dims**: 可选参数,表示前多少维度作为批处理维度,默认为 0。 - **name**: 运算名称(可选)。 #### 示例代码 下面是一个简单的例子展示如何使用 `tf.gather` 来从张量中选取特定部分的内容。 ```python import tensorflow as tf # 定义输入张量 tensor = tf.constant([[1, 2, 3], [4, 5, 6]]) # 定义索引 indices = tf.constant([0, 1]) # 默认沿着第一个轴 (axis=0) 获取元素 result_axis_0 = tf.gather(tensor, indices) # 修改 axis 参数来改变行为 result_axis_1 = tf.gather(tensor, indices, axis=1) print("Original Tensor:") print(tensor.numpy()) print("\nGathering along axis 0:") print(result_axis_0.numpy()) # 输出 [[1, 2, 3], [4, 5, 6]] print("\nGathering along axis 1:") print(result_axis_1.numpy()) # 输出 [[1, 2], [4, 5]] ``` 上述代码展示了两种不同的方式调用 `tf.gather` 方法:一种是在默认情况下按行收集;另一种则是按照列的方向进行收集[^1]。 #### 注意事项 当尝试优化性能时,建议尽量减少 Python 控制流语句的影响,并采用 TensorFlow 原生支持的操作替代标准库实现。例如,在涉及布尔值初始化的地方应该优先考虑使用 `tf.constant(True)` 而不是普通的 Python 布尔值 True。 另外需要注意的是,尽管可以在动态模式下自由混合使用常规 Python 和 TensorFlow API 编写逻辑,但如果希望利用自动图技术,则需严格遵守其编码准则以确保兼容性和效率最大化[^3]。 #### 关联文档链接 对于更详细的官方描述以及额外选项解释,请访问 TensorFlow GitHub 发布页面上的具体版本更新日志[^2] 或者查阅最新版 TensorFlow 文档关于此函数的具体章节。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值