# 主要讲Tensor数据的索引和切片 — Notes on indexing and slicing TensorFlow tensors.
import tensorflow as tf
# Basic (chained) indexing: one subscript per axis, a[i][j][k];
# every integer index removes the axis it selects from.
a = tf.ones([1, 5, 5, 3])
assert a[0][0].shape == [5, 3]
assert a[0][0][2].shape == [3]
# NumPy-style indexing: several comma-separated indices in one subscript,
# a[i, j, k], giving the same result shape as the chained form a[i][j][k].
a = tf.random.normal([4, 28, 28, 5])
assert a[1].shape == [28, 28, 5]
assert a[1, 2].shape == [28, 5]
assert a[1, 2, 3].shape == [5]
# start:end slicing: start is inclusive, end is exclusive, negative positions
# count from the end of the axis, and an omitted bound runs to that edge.
# (Uses `a` with shape [4, 28, 28, 5] from the section above.)
assert a[-1:].shape == [1, 28, 28, 5]  # last element of axis 0
assert a[-2:].shape == [2, 28, 28, 5]  # last two
assert a[:2].shape == [2, 28, 28, 5]   # first two
assert a[:-1].shape == [3, 28, 28, 5]  # everything except the last
# Integer indices and slices can be mixed: integer positions drop their axis,
# slices keep theirs.
assert a[0, 1, :, 4].shape == [28]
# start:end:step slicing; `::step` keeps every step-th element of the whole axis.
assert a[0:2, :, :, :].shape == [2, 28, 28, 5]
assert a[:, ::2, ::2, :].shape == [4, 14, 14, 5]  # halves the two middle axes

# Reverse order: a negative step walks the axis backwards, so start must lie
# after end (start:end:-1); `::-1` reverses the entire axis.
assert a[::-1].shape == [4, 28, 28, 5]
assert bool(tf.reduce_all(tf.equal(a[::-1][0], a[3])))  # first of reversed == last
assert a[::-2].shape == [2, 28, 28, 5]    # elements 3 and 1 of axis 0
assert a[3:0:-1].shape == [3, 28, 28, 5]  # elements 3, 2, 1 (end is exclusive)
# `...` (Ellipsis) stands for as many full `:` slices as needed to cover the
# axes that are not indexed explicitly.
assert a[0, ..., 2].shape == [28, 28]  # same as a[0, :, :, 2]
assert a[1, 0, ..., 0].shape == [28]   # same as a[1, 0, :, 0]
# Selective indexing: tf.gather, tf.gather_nd, tf.boolean_mask
# tf.gather: select an arbitrary list of indices along a single axis.
# Example data layout: [classes, students, subjects].
a = tf.random.normal([4, 35, 8])
assert tf.gather(a, axis=0, indices=[2, 3]).shape == [2, 35, 8]       # 2 classes
assert tf.gather(a, axis=1, indices=[2, 7, 3, 9]).shape == [4, 4, 8]  # 4 students
assert tf.gather(a, axis=2, indices=[2, 7, 3]).shape == [4, 35, 3]    # 3 subjects
# tf.gather_nd: the innermost dimension of `indices` is a coordinate into the
# leading axes of `a`; the remaining (unindexed) axes are returned whole.
# (Uses `a` with shape [4, 35, 8] from the section above.)
assert tf.gather_nd(a, [0]).shape == [35, 8]                 # like a[0]
assert tf.gather_nd(a, [0, 1]).shape == [8]                  # like a[0, 1]
assert tf.gather_nd(a, [[0, 1, 2], [0, 2, 3]]).shape == [2]  # a[0,1,2], a[0,2,3]
assert tf.gather_nd(a, [[0, 1], [1, 2]]).shape == [2, 8]     # a[0,1], a[1,2]
# tf.boolean_mask: keep the positions where `mask` is True. Without `axis` the
# mask applies to axis 0; with `axis=k` it applies to axis k. In these examples
# the mask length equals the size of the masked axis.
a = tf.random.normal([4, 35, 8, 3])
assert tf.boolean_mask(a, mask=[True, True, False, False]).shape == [2, 35, 8, 3]
assert tf.boolean_mask(a, mask=[True, True, False], axis=3).shape == [4, 35, 8, 2]