tf.where tf.gather tf.gather_nd用法示例

本文介绍了TensorFlow中的tf.where和tf.gather系列函数的使用方法。tf.where根据条件选择元素,当条件为True时返回x对应位置的值,否则返回y的值。tf.gather通过指定索引从张量中提取子集,tf.gather_nd则允许在多维索引中选取元素。通过示例展示了这些函数的用法及其在不同情况下的输出结果。
摘要由CSDN通过智能技术生成

tf.where

tf.where(condition, x=None, y=None, name=None)
# condition, x, y 相同维度,condition是bool型值
# 返回condition中元素为True对应的索引
>>> condition1 = [[True,False,False],
                   [False,True,True]]
[[0 0]
 [1 1]
 [1 2]]
# 如果有 x y 输入,condition为True用x的对应位置替换,为False则用y
# 下例:
import tensorflow as tf
x = [[1,2,3],[4,5,6]]
y = [[7,8,9],[10,11,12]]
condition3 = [[True,False,False],
             [False,True,True]]
condition4 = [[True,False,False],
             [True,True,False]]
with tf.Session() as sess:
    print(sess.run(tf.where(condition3,x,y)))
    print(sess.run(tf.where(condition4,x,y)))  
# 输出:
1[[ 1  8  9]
    [10  5  6]]
2[[ 1  8  9]
    [ 4  5 12]]

tf.gather 和 tf.gather_nd

这俩都是通过索引来切片的方法:

tf.gather(params,indices,axis=0 )
# 从params的axis维根据indices的参数值获取切片

示例:

import numpy as np
import tensorflow as tf


probs = np.array([
    [0, 11, 21, 31, 41, 51, 61, 71, 81],
    [0, 12, 22, 32, 42, 52, 62, 72, 82],
    [0, 13, 23, 33, 43, 53, 63, 73, 83],
    [0, 14, 24, 34, 44, 54, 64, 74, 84]
])

indices_nd = np.array([
    [0, 7],
    [1, 6],
    [2, 6],
    [3, 1]
])

indices_0 = np.array([1, 3])
indices_1 = np.array([7, 3])

with tf.Session() as sess:
    print("tf.gather  axis=0 \n", sess.run(tf.gather(probs, indices_0, axis=0)))
    print("tf.gather  axis=1 \n", sess.run(tf.gather(probs, indices_1, axis=1)))
    print("tf.gather_nd", sess.run(tf.gather_nd(probs, indices_nd)))

输出:

tf.gather  axis=0 
 [[ 0 12 22 32 42 52 62 72 82]
 [ 0 14 24 34 44 54 64 74 84]]

tf.gather  axis=1 
 [[71 31]
 [72 32]
 [73 33]
 [74 34]]

tf.gather_nd [71 62 63 14]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值