【tf2函数】tf.gather()

本文详细介绍了TensorFlow中的tf.gather()函数,展示了如何使用该函数从一维和二维张量中按指定索引进行数据切片。通过示例代码解释了该函数在不同维度过滤数据的操作,并给出了相应的输出结果。
摘要由CSDN通过智能技术生成

参考链接传送门

tf.gather():

tf.gather() 单一维度方向的数据,进行任意顺序的切片。

  • 源张量是一维数据

    • 代码:
      import tensorflow as tf
      print("源张量")
      a = tf.constant([0,1,2,3,4,5,6,7,8,9])
      print(a)
       
      b = tf.gather(a, indices=[0,2,4,6,8,1,3,5,7,9])
      print(b)
           
      b = tf.gather(a, indices=[0,2,4])
      print(b)
      
      输出:
      源张量
      tf.Tensor([0 1 2 3 4 5 6 7 8 9], shape=(10,), dtype=int32)
      tf.Tensor([0 2 4 6 8 1 3 5 7 9], shape=(10,), dtype=int32)
      tf.Tensor([0 2 4], shape=(3,), dtype=int32)
      
  • 源张量是二维数据

    • 代码
      import tensorflow as tf
       
      print("原张量")
      a = tf.constant([[0,1,2,3,4],[5,6,7,8,9], [10,11,12,13,14]])
      print(a)
       
      print("\n过滤之后张量(axis=0):")
      b = tf.gather(a, axis=0, indices=[0, 2]) #只能过滤一个维度方向
      print(b)
       
      print("\n过滤之后张量(axis=1):") 
      b = tf.gather(a, axis=1, indices=[1,3]) #只能过滤一个维度方向
      print(b)
      
      输出:
      <tf.Tensor: shape=(3, 5), dtype=int32, numpy=
      array([[ 0,  1,  2,  3,  4],
             [ 5,  6,  7,  8,  9],
             [10, 11, 12, 13, 14]])>
       
      过滤之后张量(axis=0):
      <tf.Tensor: shape=(2, 5), dtype=int32, numpy=
      array([[ 0,  1,  2,  3,  4],
             [10, 11, 12, 13, 14]])>
       
      过滤之后张量(axis=1):
      <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
      array([[ 1,  3],
             [ 6,  8],
             [11, 13]])>
      
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>