tf.where tf.gather tf.gather_nd用法示例

本文介绍了TensorFlow中的tf.where和tf.gather系列函数的使用方法。tf.where根据条件选择元素,当条件为True时返回x对应位置的值,否则返回y的值。tf.gather通过指定索引从张量中提取子集,tf.gather_nd则允许在多维索引中选取元素。通过示例展示了这些函数的用法及其在不同情况下的输出结果。
摘要由CSDN通过智能技术生成

tf.where

tf.where(condition, x=None, y=None, name=None)
# condition, x, y 相同维度,condition是bool型值
# 返回condition中元素为True对应的索引
>>> condition1 = [[True,False,False],
                   [False,True,True]]
[[0 0]
 [1 1]
 [1 2]]
# 如果有 x y 输入,condition为True用x的对应位置替换,为False则用y
# 下例:
import tensorflow as tf
x = [[1,2,3],[4,5,6]]
y = [[7,8,9],[10,11,12]]
condition3 = [[True,False,False],
             [False,True,True]]
condition4 = [[True,False,False],
             [True,True,False]]
with tf.Session() as sess:
    print(sess.run(tf.where(condition3,x,y)))
    print(sess.run(tf.where(condition4,x,y)))  
# 输出:
1[[ 1  8  9]
    [10  5  6]]
2[[ 1  8  9]
    [ 4  5 12]]

tf.gather 和 tf.gather_nd

这俩都是通过索引来切片的方法:

tf.gather(params,indices,axis=0 )
# 从params的axis维根据indices的参数值获取切片

示例:

import numpy as np
import tensorflow as tf


probs = np.array([
    [0, 11, 21, 31, 41, 51, 61, 71, 81],
    [0, 12, 22, 32, 42, 52, 62, 72, 82],
    [0, 13, 23, 33, 43, 53, 63, 73, 83],
    [0, 14, 24, 34, 44, 54, 64, 74, 84]
])

indices_nd = np.array([
    [0, 7],
    [1, 6],
    [2, 6],
    [3, 1]
])

indices_0 = np.array([1, 3])
indices_1 = np.array([7, 3])

with tf.Session() as sess:
    print("tf.gather  axis=0 \n", sess.run(tf.gather(probs, indices_0, axis=0)))
    print("tf.gather  axis=1 \n", sess.run(tf.gather(probs, indices_1, axis=1)))
    print("tf.gather_nd", sess.run(tf.gather_nd(probs, indices_nd)))

输出:

tf.gather  axis=0 
 [[ 0 12 22 32 42 52 62 72 82]
 [ 0 14 24 34 44 54 64 74 84]]

tf.gather  axis=1 
 [[71 31]
 [72 32]
 [73 33]
 [74 34]]

tf.gather_nd [71 62 63 14]
  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
`tf.gather`函数是 TensorFlow 中的一个操作,用于根据索引从张量中收集元素。它的语法如下: ```python tf.gather(params, indices, axis=None, batch_dims=0, name=None) ``` 参数说明: - `params`: 要从中收集元素的张量,可以是任何形状的张量。 - `indices`: 一个张量,指定要收集哪些元素。它可以是任何形状的整数张量。 - `axis`: 一个可选的整数,指定从哪个轴收集元素。默认为`None`,表示要将`indices`解释为一维向量。如果指定`axis`,则`indices`必须是具有相同形状的张量。 - `batch_dims`: 一个可选的整数,指定批次维度的数量。默认为0,表示没有批次维度。例如,如果`params`的形状是`(batch_size, height, width, channels)`,则`batch_dims`可以设置为1,表示`indices`的形状是`(batch_size, num_elements)`。 - `name`: (可选)操作的名称。 `tf.gather`函数的返回值是一个张量,其中包含来自`params`的元素,其索引由`indices`指定。 下面是一个使用`tf.gather`函数的示例代码: ```python import tensorflow as tf # 创建一个3x3的矩阵 x = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 从第1个轴(即行)收集第0行和第2行 indices = tf.constant([0, 2]) y = tf.gather(x, indices, axis=0) print(y.numpy()) # 输出 [[1 2 3] [7 8 9]] ``` 在这个例子中,我们创建了一个3x3的矩阵`x`,然后使用`tf.gather`函数从第一个轴(即行)收集第0行和第2行,得到了形状为`(2, 3)`的矩阵`y`。注意,`indices`的形状是`(2,)`,因为我们没有指定`axis`参数。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值