一、tf.gather
tf.gather
可以实现根据索引号收集数据的目的。
考虑班级成绩册的例子,共有4 个班级,每个班级35 个学生,8 门科目,保存成绩册的张量shape 为[4,35,8]。
x = tf.random.uniform([4,35,8],maxval=100,dtype=tf.int32)
现在需要收集第1-2 个班级的成绩册,可以给定需要收集班级的索引号:[0,1],班级的维
度axis=0:
tf.gather(x,[0,1],axis=0) # 在班级维度收集第1-2 号班级成绩册
# 收集第1,4,9,12,13,27 号同学成绩
tf.gather(x,[0,3,8,11,12,26],axis=1)
#如果需要收集所有同学的第3,5 等科目的成绩,则可以:
tf.gather(x,[2,4],axis=2) # 第3,5 科目的成绩
可以看到,tf.gather 非常适合索引没有规则的场合,其中索引号可以乱序排列,此时收集的数据也是对应顺序:
tf.gather(a,[3,1,0,2],axis=0) # 收集第4,2,1,3 号元素
我们将问题变得复杂一点:如果希望抽查第[2,3]班级的第[3,4,6,27]号同学的科目成
绩,则可以通过组合多个tf.gather 实现。首先抽出第[2,3]班级:
students=tf.gather(x,[1,2],axis=0) # 收集第2,3 号班级
tf.gather(students,[2,3,5,26],axis=1) # 收集第3,4,6,27 号同学
二、tf.gather_nd
通过 tf.gather_nd
,可以通过指定每次采样的坐标来实现采样多个点的目的
我们希望抽查第2 个班级的第2 个同学的所有科目,第3 个班级的第3 个同学的
所有科目,第4 个班级的第4 个同学的所有科目。那么这3 个采样点的索引坐标可以记
为:[1,1], [2,2], [3,3],我们将这个采样方案合并为一个List 参数:[[1,1], [2,2], [3,3]],通过
tf.gather_nd 实现如下:
# 根据多维度坐标收集数据
tf.gather_nd(x,[[1,1],[2,2],[3,3]])
一般地,在使用tf.gather_nd
采样多个样本时,如果希望采样第 i 号班级,第 j 个学
生,第 k 门科目的成绩,则可以表达为[. . . , [𝑖, 𝑗, 𝑘], . . . ]
,外层的括号长度为采样样本的个数,内层列表包含了每个采样点的索引坐标:
tf.gather_nd(x,[[1,1,2],[2,2,3],[3,3,4]])
上述代码中,我们抽出了班级1,学生1 的科目2;班级2,学生2 的科目3;班级3,学
生3 的科目4 的成绩,共有3 个成绩数据,结果汇总为一个shape 为[3]的张量。
三、tf.where
通过 tf.where(cond, a, b)
操作可以根据 cond 条件的真假从 a 或 b 中读取数据,条件判定规则如下:
其中 i 为张量的索引,返回张量大小与a,b 张量一致,当对应位置中𝑐𝑜𝑛𝑑𝑖为True,𝑜𝑖位置
从𝑎𝑖中复制数据;当对应位置中𝑐𝑜𝑛𝑑𝑖为False,𝑜𝑖位置从𝑏𝑖中复制数据。
考虑从2 个全1、全0 的3x3 大小的张量a,b 中提取数据,其中cond 为True 的位置从a 中对应位置提取,cond 为False 的位置从b 对应位置提取:
可以看到,返回的张量中为1 的位置来自张量a,返回的张量中为0 的位置来自张量b。
当 a=b=None 即a,b 参数不指定时,tf.where
会返回cond 张量中所有True 的元素的索引坐标。考虑如下cond 张量:
其中True 共出现4 次,每个True 位置处的索引分布为[0,0], [1,1], [2,0], [2,1],可以直接通
过tf.where(cond)
来获得这些索引坐标
那么这有什么用途呢?考虑一个例子,我们需要提取张量中所有正数的数据和索引。首先
构造张量a,并通过比较运算得到所有正数的位置掩码:
通过 tf.where
提取此掩码处 TRUE 元素的索引:
indices = tf.where(mask)
拿到索引后,通过tf.gather_nd
即可恢复出所有正数的元素:
tf.gather_nd(x,indices) # 提取正数的元素值
实际上,当我们得到掩码 mask 之后,也可以直接通过tf.boolean_mask
获取对应元素:
四、scatter_nd
通过 tf.scatter_nd(indices, updates, shape)
可以高效地刷新张量的部分数据,但是只能在全0 张量的白板上面刷新,因此可能需要结合其他操作来实现现有张量的数据刷新功能。
如下图所示,演示了一维张量白板的刷新运算,白板的形状表示为shape 参数,需要刷新的数据索引为indices,新数据为updates,其中每个需要刷新的数据对应在白板中的位置,根据indices 给出的索引位置将updates 中新的数据依次写入白板中,并返回更新后的白板张量。
# 构造需要刷新数据的位置
indices = tf.constant([[4],[3],[1],[7]])
# 构造需要写入的数据
updates = tf.constant([4.4, 3.3, 1.1, 7.7])
# 在长度为8 的全0 向量上根据indices 写入updates
tf.scatter_nd(indices, updates, [8])
考虑 3 维张量的刷新例子,如下图所示,白板shape 为[4,4,4],共有4 个通道的特
征图,现有需2 个通道的新数据updates:[2,4,4],需要写入索引为[1,3]的通道上:
五、meshgrid
通过 tf.meshgrid
可以方便地生成二维网格采样点坐标,方便可视化等应用场合。
考虑2 个自变量x,y 的Sinc 函数表达式为:
如果需要绘制函数在𝑥 ∈ [−8,8], 𝑦 ∈ [−8,8]区间的Sinc 函数的3D 曲面,如图所示,则首先需要生成x,y 的网格点坐标{(𝑥, 𝑦)},这样才能通过Sinc 函数的表达式计算函数在每个
(𝑥, 𝑦)位置的输出值z。
通过在 x 轴上进行采样100 个数据点,y 轴上采样100 个数据点,然后通过
tf.meshgrid(x, y)
即可返回这10000 个数据点的张量数据,shape 为[100,100,2]。为了方便计算,tf.meshgrid 会返回在axis=2 维度切割后的2 个张量a,b,其中张量a 包含了所有点的x坐标,b 包含了所有点的y 坐标,shape 都为[100,100]:
x = tf.linspace(-8.,8,100) # 设置x 坐标的间隔
y = tf.linspace(-8.,8,100) # 设置y 坐标的间隔
x,y = tf.meshgrid(x,y) # 生成网格点,并拆分后返回
x.shape,y.shape # 打印拆分后的所有点的x,y 坐标张量shape
Sinc 函数在TensorFlow 中实现如下:
z = tf.sqrt(x**2+y**2)
z = tf.sin(z)/z # sinc 函数实现
通过matplotlib 即可绘制出函数在𝑥 ∈ [−8,8], 𝑦 ∈ [−8,8]区间的3D 曲面
fig = plt.figure()
ax = Axes3D(fig)
# 根据网格点绘制sinc 函数3D 曲面
ax.contour3D(x.numpy(), y.numpy(), z.numpy(), 50)
plt.show()