**一、tf.gather() 函数**
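- 函数签名（早期 TF 1.x 写法）：`tf.gather(params, indices, validate_indices=None, name=None, axis=0)`；TF 2.x 中为 `tf.gather(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None)`，axis 为 None 时默认取第一个非 batch 维度（batch_dims=0 时即第 0 维）。
- params：待取数据的张量
- indices：要取出的数据在 axis 维上的索引（整数列表或整数张量）
- axis：指定在哪个维度上取数据

例如 `tf.gather(a, axis=0, indices=[2,3])`：在 a 的第 0 维上取出索引为 2 和 3 的数据，效果与切片 `a[2:4]` 相同。不同之处在于 indices 可以不连续，也可以按任意顺序排列（如下例中的 `[2,1,3,0]`），这是普通切片做不到的。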
In [1]: import tensorflow as tf
In [2]: a = tf.random.normal([4,35,8],mean=1,stddev=1)
In [3]: a.shape
Out[3]: TensorShape([4, 35, 8])
In [4]: tf.gather(a,axis=0,indices=[2,3]).shape
Out[4]: TensorShape([2, 35, 8])
In [5]: a[2:4].shape
Out[5]: TensorShape([2, 35, 8])
In [6]: tf.gather(a,axis=0,indices=[2,1,3,0]).shape
Out[6]: TensorShape([4, 35, 8])
In [7]: tf.gather(a,axis=1,indices=[2,3,7,9,16]).shape
Out[7]: TensorShape([4, 5, 8])
In [8]: tf.gather(a,axis=2,indices = [2,3,7]).shape
Out[8]: TensorShape([4, 35, 3])
利用 tf.gather() 对数据进行随机打散：先用 tf.random.shuffle 打乱索引 idx，再用同一组 idx 分别对图片 a 和标签 y 做 gather。这样打散后，图片与标签仍然一一对应。
In [1]: import tensorflow as tf
In [2]: idx = tf.range(100) #索引
In [3]: idx
Out[3]:
<tf.Tensor: id=3, shape=(100,), dtype=int32, numpy=
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
dtype=int32)>
In [4]: idx = tf.random.shuffle(idx)#对索引进行随机打散
In [5]: idx
Out[5]:
<tf.Tensor: id=5, shape=(100,), dtype=int32, numpy=
array([11, 58, 72, 35, 86, 54, 2, 85, 88, 48, 23, 17, 46, 33, 79, 32, 62,
70, 28, 13, 99, 83, 20, 37, 29, 61, 16, 98, 52, 63, 55, 97, 31, 44,
92, 89, 94, 47, 76, 73, 67, 7, 95, 87, 1, 42, 38, 5, 10, 64, 26,
51, 56, 40, 65, 59, 39, 90, 24, 27, 25, 21, 75, 84, 77, 4, 30, 34,
50, 69, 57, 22, 71, 15, 49, 43, 6, 81, 66, 80, 82, 45, 96, 3, 36,
78, 91, 68, 41, 8, 14, 74, 0, 18, 9, 19, 60, 93, 53, 12],
dtype=int32)>
In [6]: a = tf.random.uniform([100,64,64,3],minval=0,maxval=100,dtype=tf.int32)#随机生成100张图片数组
In [7]: y = tf.range(100)
In [8]: y = tf.one_hot(y,depth=100)#标签y
In [9]: y
Out[9]:
<tf.Tensor: id=18, shape=(100, 100), dtype=float32, numpy=
array([[1., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
[0., 0., 1., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 1.]], dtype=float32)>
In [11]: a = tf.gather(a,indices=idx,axis=0)#对数据利用索引进行打散
In [12]: y = tf.gather(y,indices=idx,axis=0)#对输出利用索引进行打散
In [13]: y
Out[13]:
<tf.Tensor: id=23, shape=(100, 100), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>
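**二、tf.gather_nd() 函数**

tf.gather_nd(params, indices) 可以同时在多个维度上进行索引。indices 的最后一维给出在 params 前若干个维度上的联合坐标，其余维度只用来组织多个坐标。结果的形状为 `indices.shape[:-1] + params.shape[indices.shape[-1]:]`。例如下面 a 的形状为 [4,35,8]，`[[0,0],[1,1]]` 的形状为 [2,2]，每个坐标定位到前两维，所以结果形状为 [2] + [8] = [2,8]。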
In [1]: import tensorflow as tf
In [2]: a = tf.random.normal([4,35,8])
In [3]: a.shape
Out[3]: TensorShape([4, 35, 8])
In [4]: tf.gather_nd(a,[0]).shape
Out[4]: TensorShape([35, 8])
In [5]: tf.gather_nd(a,[0,1]).shape
Out[5]: TensorShape([8])
In [6]: tf.gather_nd(a,[0,1,2]).shape
Out[6]: TensorShape([])
In [7]: tf.gather_nd(a,[[0,1,2]]).shape
Out[7]: TensorShape([1])
In [8]: tf.gather_nd(a,[[0,0],[1,1]]).shape
Out[8]: TensorShape([2, 8])
In [9]: tf.gather_nd(a,[[0,0],[1,1],[2,2]]).shape
Out[9]: TensorShape([3, 8])
In [10]: tf.gather_nd(a,[[0,0,0],[1,1,1],[2,2,2]]).shape
Out[10]: TensorShape([3])
In [11]: tf.gather_nd(a,[[[0,0,0],[1,1,1],[2,2,2]]]).shape
Out[11]: TensorShape([1, 3])
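**三、tf.boolean_mask() 函数**

tf.boolean_mask(tensor, mask, axis=None) 用布尔掩码选取数据，mask 中为 True 的位置被保留。mask 的形状必须与 tensor 从 axis 维开始的对应维度一致（axis 缺省时为 0）。结果中被 mask 覆盖的这几个维度会合并为一维，其长度等于 True 的个数。例如下面最后一例中，mask 的形状为 [2,3]，与 a 的前两维一致，共有 3 个 True，因此结果形状为 [3,4]。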
In [1]: import os
In [2]: import tensorflow as tf
In [3]: os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
In [4]: a = tf.random.normal([4,28,28,3],mean=1,stddev=1)
In [5]: a.shape
Out[5]: TensorShape([4, 28, 28, 3])
In [6]: tf.boolean_mask(a,mask=[True,True,False,False],axis=0).shape
Out[6]: TensorShape([2, 28, 28, 3])
In [7]: tf.boolean_mask(a,mask=[True,True,False],axis=3).shape
Out[7]: TensorShape([4, 28, 28, 2])
In [8]: a = tf.ones([2,3,4])
In [9]: a.shape
Out[9]: TensorShape([2, 3, 4])
In [10]: tf.boolean_mask(a,mask = [[True,False,False],[False,True,True]])
Out[10]:
<tf.Tensor: id=92, shape=(3, 4), dtype=float32, numpy=
array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32)>