TensorFlow:高级操作

一、tf.gather

tf.gather 可以实现根据索引号收集数据的目的。

考虑班级成绩册的例子,共有4 个班级,每个班级35 个学生,8 门科目,保存成绩册的张量shape 为[4,35,8]。

x = tf.random.uniform([4,35,8],maxval=100,dtype=tf.int32)

现在需要收集第1-2 个班级的成绩册,可以给定需要收集班级的索引号:[0,1],班级的维
度axis=0:

tf.gather(x,[0,1],axis=0) # 在班级维度收集第1-2 号班级成绩册
# 收集第1,4,9,12,13,27 号同学成绩
tf.gather(x,[0,3,8,11,12,26],axis=1)
#如果需要收集所有同学的第3,5 等科目的成绩,则可以:
tf.gather(x,[2,4],axis=2) # 第3,5 科目的成绩

可以看到,tf.gather 非常适合索引没有规则的场合,其中索引号可以乱序排列,此时收集的数据也是对应顺序

tf.gather(a,[3,1,0,2],axis=0) # 收集第4,2,1,3 号元素

我们将问题变得复杂一点:如果希望抽查第[2,3]班级的第[3,4,6,27]号同学的科目成
绩,则可以通过组合多个tf.gather 实现。首先抽出第[2,3]班级:

students=tf.gather(x,[1,2],axis=0) # 收集第2,3 号班级
tf.gather(students,[2,3,5,26],axis=1) # 收集第3,4,6,27 号同学

二、tf.gather_nd

通过 tf.gather_nd,可以通过指定每次采样的坐标来实现采样多个点的目的

我们希望抽查第2 个班级的第2 个同学的所有科目,第3 个班级的第3 个同学的
所有科目,第4 个班级的第4 个同学的所有科目。那么这3 个采样点的索引坐标可以记
为:[1,1], [2,2], [3,3],我们将这个采样方案合并为一个List 参数:[[1,1], [2,2], [3,3]],通过
tf.gather_nd 实现如下:

# 根据多维度坐标收集数据
tf.gather_nd(x,[[1,1],[2,2],[3,3]])

一般地,在使用tf.gather_nd 采样多个样本时,如果希望采样第 i 号班级,第 j 个学
生,第 k 门科目的成绩,则可以表达为[. . . , [𝑖, 𝑗, 𝑘], . . . ],外层的括号长度为采样样本的个数,内层列表包含了每个采样点的索引坐标:

tf.gather_nd(x,[[1,1,2],[2,2,3],[3,3,4]])

上述代码中,我们抽出了班级1,学生1 的科目2;班级2,学生2 的科目3;班级3,学
生3 的科目4 的成绩,共有3 个成绩数据,结果汇总为一个shape 为[3]的张量。

三、tf.where

通过 tf.where(cond, a, b)操作可以根据 cond 条件的真假从 a 或 b 中读取数据,条件判定规则如下:
在这里插入图片描述
其中 i 为张量的索引,返回张量大小与a,b 张量一致,当对应位置中𝑐𝑜𝑛𝑑𝑖为True,𝑜𝑖位置
从𝑎𝑖中复制数据;当对应位置中𝑐𝑜𝑛𝑑𝑖为False,𝑜𝑖位置从𝑏𝑖中复制数据。
考虑从2 个全1、全0 的3x3 大小的张量a,b 中提取数据,其中cond 为True 的位置从a 中对应位置提取,cond 为False 的位置从b 对应位置提取:
在这里插入图片描述
可以看到,返回的张量中为1 的位置来自张量a,返回的张量中为0 的位置来自张量b。

当 a=b=None 即a,b 参数不指定时,tf.where 会返回cond 张量中所有True 的元素的索引坐标。考虑如下cond 张量:
在这里插入图片描述
其中True 共出现4 次,每个True 位置处的索引分布为[0,0], [1,1], [2,0], [2,1],可以直接通
tf.where(cond)来获得这些索引坐标

那么这有什么用途呢?考虑一个例子,我们需要提取张量中所有正数的数据和索引。首先
构造张量a,并通过比较运算得到所有正数的位置掩码:
在这里插入图片描述
通过 tf.where 提取此掩码处 TRUE 元素的索引:

indices = tf.where(mask)

拿到索引后,通过tf.gather_nd 即可恢复出所有正数的元素:

tf.gather_nd(x,indices) # 提取正数的元素值

在这里插入图片描述
实际上,当我们得到掩码 mask 之后,也可以直接通过tf.boolean_mask 获取对应元素:
在这里插入图片描述

四、scatter_nd

通过 tf.scatter_nd(indices, updates, shape)可以高效地刷新张量的部分数据,但是只能在全0 张量的白板上面刷新,因此可能需要结合其他操作来实现现有张量的数据刷新功能。

如下图所示,演示了一维张量白板的刷新运算,白板的形状表示为shape 参数,需要刷新的数据索引为indices,新数据为updates,其中每个需要刷新的数据对应在白板中的位置,根据indices 给出的索引位置将updates 中新的数据依次写入白板中,并返回更新后的白板张量。
在这里插入图片描述

# 构造需要刷新数据的位置
indices = tf.constant([[4],[3],[1],[7]])
# 构造需要写入的数据
updates = tf.constant([4.4, 3.3, 1.1, 7.7])
# 在长度为8 的全0 向量上根据indices 写入updates
tf.scatter_nd(indices, updates, [8])

考虑 3 维张量的刷新例子,如下图所示,白板shape 为[4,4,4],共有4 个通道的特
征图,现有需2 个通道的新数据updates:[2,4,4],需要写入索引为[1,3]的通道上:
在这里插入图片描述
在这里插入图片描述

五、meshgrid

通过 tf.meshgrid 可以方便地生成二维网格采样点坐标,方便可视化等应用场合。
考虑2 个自变量x,y 的Sinc 函数表达式为:
在这里插入图片描述
如果需要绘制函数在𝑥 ∈ [−8,8], 𝑦 ∈ [−8,8]区间的Sinc 函数的3D 曲面,如图所示,则首先需要生成x,y 的网格点坐标{(𝑥, 𝑦)},这样才能通过Sinc 函数的表达式计算函数在每个
(𝑥, 𝑦)位置的输出值z。

通过在 x 轴上进行采样100 个数据点,y 轴上采样100 个数据点,然后通过
tf.meshgrid(x, y)即可返回这10000 个数据点的张量数据,shape 为[100,100,2]。为了方便计算,tf.meshgrid 会返回在axis=2 维度切割后的2 个张量a,b,其中张量a 包含了所有点的x坐标,b 包含了所有点的y 坐标,shape 都为[100,100]:

x = tf.linspace(-8.,8,100) # 设置x 坐标的间隔
y = tf.linspace(-8.,8,100) # 设置y 坐标的间隔
x,y = tf.meshgrid(x,y) # 生成网格点,并拆分后返回
x.shape,y.shape # 打印拆分后的所有点的x,y 坐标张量shape

Sinc 函数在TensorFlow 中实现如下:

z = tf.sqrt(x**2+y**2)
z = tf.sin(z)/z # sinc 函数实现

通过matplotlib 即可绘制出函数在𝑥 ∈ [−8,8], 𝑦 ∈ [−8,8]区间的3D 曲面

fig = plt.figure()
ax = Axes3D(fig)
# 根据网格点绘制sinc 函数3D 曲面
ax.contour3D(x.numpy(), y.numpy(), z.numpy(), 50)
plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

南淮北安

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值