Tensorflow(三) —— Tensor的索引与切片
1 主要的几种索引方式
- 1、basic indexing
- 2、same with numpy
- 3、gather
- 4、gather_nd
- 5、boolean_mask
2 basic indexing
data = tf.random.normal([4,28,28,3])
print(data[0].shape) # 取出第一张照片
print(data[0][0].shape) # 取出第一张照片的第一行
print(data[0][0][0].shape) # 取出第一张照片的第一行第一列的三个rgb通道
print(data[0][0][0][0].shape) # 取出第一张照片的第一行第一列的三个r通道 即为一个标量
3 numpy_style indexing
a = tf.random.uniform([4,28,28,3])
print(data[0].shape) # 取出第一张照片
print(data[0,0].shape) # 取出第一张照片的第一行
print(data[0,0,0].shape) # 取出第一张照片的第一行第一列的三个rgb通道
print(data[0,0,0,0].shape) # 取出第一张照片的第一行第一列的三个r通道 即为一个标量
4 start:end
a = tf.ones([4,28,28,3])
# 返回最后一个元素构成的向量 切片总会返回一个向量
print(a[0,0,0,-1:].shape)
# 返回最后一个元素构成的向量 切片总会返回一个向量
print(a[0,0,0,:1].shape)
# 返回倒数二个元素
print(a[0,0,0,-2:].shape)
# 返回前两个元素
print(a[0,0,0,:2].shape)
# 不包含最后一个元素
print(a[0,0,0,:-1].shape)
print("*"*18)
print(a[0,:,:,:].shape) # 取出第一张照片
print(a[0,0,:,:].shape) # 取出第一张照片的第一行
print(a[0,0,0,:].shape) # 取出第一张照片的第一行第一列的三个rgb通道
print(a[0,0,0,0].shape) # 取出第一张照片的第一行第一列的三个r通道 即为一个标量
print(a[:,:,:,1].shape) # 取出所有图片的g通道
print(a[:,1,:,:].shape)# 所有照片的一行的rgb通道
5 start🔚step
import numpy as np
data = np.random.rand(4,28,28,3)
a = tf.convert_to_tensor(data)
# 返回每张照片隔一行和隔一列的rgb通道
print(a[:,::2,::2,:].shape)
# 返回每张照片的前14行和14列的rgb通道
print(a[:,:14,:14,:].shape)
# 返回每张照片的后14行和14列的rgb通道
print(a[:,14:,14:,:].shape)
6 ::-1 实现倒序功能
a = tf.range(10)
print(a)
"""
当步长为负数时,切片的左端代表最末端,右端代表最首位
"""
# 实现逆序
print(a[::-1])
print(a[::-2]) # 从尾至首 以二为步长索引
print(a[2::-2]) # 从3号元素开始逆序索引
7 … 代表任意长
a = tf.random.uniform([10,100,28,28,3],maxval = 100,minval= 1)
# 索引第一个任务的数据
print(a[0,:,:,:,:].shape)
print(a[0,...].shape)
# 索引所有任务的r通道
print(a[...,0].shape)
# 索引第一个任务的b通道
print(a[0,...,2].shape)
# 索引第二个任务所有图片第三列的rgb通道
print(a[0,...,2,:].shape)
8 selective indexing(可选索引)
"""
非连续索引
"""
# ******************** gather索引
a = tf.random.normal([4,28,28,3])
# 获取第2和第四张图片的所有rgb通道
print(tf.gather(a,axis=0,indices = [1,3]).shape)
# 获取所有图片第2,5,7,9行的rgb通道
print(tf.gather(a,axis=1,indices = [1,4,6,8]).shape)
# 获取所有图片的r和b通道
print(tf.gather(a,axis = 3,indices = [0,2]).shape)
"""
gather方法另外一个主要的用法就是可以按给定索引顺序进行索引并返回
"""
b = tf.range(10)
print(b)
index = tf.random.shuffle(b)
print(tf.gather(b,axis = 0,indices = index))
9 gather_nd索引
"""
组合各个索引值
"""
a = tf.random.normal([4,28,28,3])
# 取单个值
print(tf.gather_nd(a,[0,0,0,0]).shape)
# 取单个值作为向量
print(tf.gather_nd(a,[[0,0,0,0]]).shape)
# 取不同图片不同行的rgb通道
print(tf.gather_nd(a,[[0,1,0,2],[1,2,5,1]]).shape)
print(tf.gather_nd(a,[[[0,0,0,0],[1,2,5,2],[2,6,8,1]]]).shape)
"""
加[]代表扩维
"""
10 tf.boolean_mask
"""
%save用bool尔值进行索引
"""
a = tf.random.normal([4,28,28,3])
# 索引第1和第三张图片
print(tf.boolean_mask(a,[True,False,True,False]).shape)
# 索引r通道
print(tf.boolean_mask(a,axis = 3,mask= [True,False,False]).shape)
# 索引某行列对应的值
b = tf.ones([2,3,4])
print(tf.boolean_mask(b,mask = [[True,False,False],[True,True,False]]).shape)
"""
布尔值个数注意对应
"""
本文为参考龙龙老师的“深度学习与TensorFlow 2入门实战“课程书写的学习笔记
by CyrusMay 2022 04 06