TensorFlow2.0:索引和切片(2)

**

一 tf.gather( ) 函数

**
–tf.gather(params, indices, validate_indices=None, name=None, axis=0)
params:待切片的参数
indices:取出数据所在的位置
axis:指定切片数据所在的维度
tf.gather(a,axis=0,indices=[2,3]): 对参数a进行切片, 对维度0进行操作,将维度0上面的数据切出2,3的数据.

In [1]: import tensorflow as tf                                                 

In [2]: a = tf.random.normal([4,35,8],mean=1,stddev=1)                          

In [3]: a.shape                                                                 
Out[3]: TensorShape([4, 35, 8])

In [4]: tf.gather(a,axis=0,indices=[2,3]).shape                                 
Out[4]: TensorShape([2, 35, 8])

In [5]: a[2:4].shape                                                            
Out[5]: TensorShape([2, 35, 8])

In [6]: tf.gather(a,axis=0,indices=[2,1,3,0]).shape                             
Out[6]: TensorShape([4, 35, 8])

In [7]: tf.gather(a,axis=1,indices=[2,3,7,9,16]).shape                          
Out[7]: TensorShape([4, 5, 8])

In [8]: tf.gather(a,axis=2,indices = [2,3,7]).shape                             
Out[8]: TensorShape([4, 35, 3])

利用tf.gather( ) 函数进行随机打散

In [1]: import tensorflow as tf                                                 

In [2]: idx = tf.range(100) #索引                                                    

In [3]: idx                                                                     
Out[3]: 
<tf.Tensor: id=3, shape=(100,), dtype=int32, numpy=
array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
       68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
       85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
      dtype=int32)>

In [4]: idx = tf.random.shuffle(idx)#对索引进行随机打散                                            

In [5]: idx                                                                     
Out[5]: 
<tf.Tensor: id=5, shape=(100,), dtype=int32, numpy=
array([11, 58, 72, 35, 86, 54,  2, 85, 88, 48, 23, 17, 46, 33, 79, 32, 62,
       70, 28, 13, 99, 83, 20, 37, 29, 61, 16, 98, 52, 63, 55, 97, 31, 44,
       92, 89, 94, 47, 76, 73, 67,  7, 95, 87,  1, 42, 38,  5, 10, 64, 26,
       51, 56, 40, 65, 59, 39, 90, 24, 27, 25, 21, 75, 84, 77,  4, 30, 34,
       50, 69, 57, 22, 71, 15, 49, 43,  6, 81, 66, 80, 82, 45, 96,  3, 36,
       78, 91, 68, 41,  8, 14, 74,  0, 18,  9, 19, 60, 93, 53, 12],
      dtype=int32)>

In [6]: a = tf.random.uniform([100,64,64,3],minval=0,maxval=100,dtype=tf.int32)#随机生成100张图片数组 

In [7]: y = tf.range(100)                                                       

In [8]: y = tf.one_hot(y,depth=100)#标签y                                             

In [9]: y                                                                       
Out[9]: 
<tf.Tensor: id=18, shape=(100, 100), dtype=float32, numpy=
array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], dtype=float32)>

In [11]: a = tf.gather(a,indices=idx,axis=0)#对数据利用索引进行打散                                    

In [12]: y = tf.gather(y,indices=idx,axis=0)#对输出利用索引进行打散                                    

In [13]: y                                                                      
Out[13]: 
<tf.Tensor: id=23, shape=(100, 100), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>

**

二 tf.gather_nd( )函数

**
在多个维度上进行索引

In [1]: import tensorflow as tf                                                 

In [2]: a = tf.random.normal([4,35,8])                                          

In [3]: a.shape                                                                 
Out[3]: TensorShape([4, 35, 8])

In [4]: tf.gather_nd(a,[0]).shape                                               
Out[4]: TensorShape([35, 8])

In [5]: tf.gather_nd(a,[0,1]).shape                                             
Out[5]: TensorShape([8])

In [6]: tf.gather_nd(a,[0,1,2]).shape                                           
Out[6]: TensorShape([])

In [7]: tf.gather_nd(a,[[0,1,2]]).shape                                         
Out[7]: TensorShape([1])
In [8]: tf.gather_nd(a,[[0,0],[1,1]]).shape                                     
Out[8]: TensorShape([2, 8])

In [9]: tf.gather_nd(a,[[0,0],[1,1],[2,2]]).shape                               
Out[9]: TensorShape([3, 8])

In [10]: tf.gather_nd(a,[[0,0,0],[1,1,1],[2,2,2]]).shape                        
Out[10]: TensorShape([3])

In [11]: tf.gather_nd(a,[[[0,0,0],[1,1,1],[2,2,2]]]).shape                      
Out[11]: TensorShape([1, 3])

**

三 tf.boolean_mask( )函数

**

In [1]: import os                                                               

In [2]: import tensorflow as tf                                                 

In [3]: os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'                                

In [4]: a = tf.random.normal([4,28,28,3],mean=1,stddev=1)                       

In [5]: a.shape                                                                 
Out[5]: TensorShape([4, 28, 28, 3])

In [6]: tf.boolean_mask(a,mask=[True,True,False,False],axis=0).shape            

Out[6]: TensorShape([2, 28, 28, 3])

In [7]: tf.boolean_mask(a,mask=[True,True,False],axis=3).shape                  
Out[7]: TensorShape([4, 28, 28, 2])

In [8]: a = tf.ones([2,3,4])                                                    

In [9]: a.shape                                                                 
Out[9]: TensorShape([2, 3, 4])

In [10]: tf.boolean_mask(a,mask = [[True,False,False],[False,True,True]])       
Out[10]: 
<tf.Tensor: id=92, shape=(3, 4), dtype=float32, numpy=
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)>

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值