深度学习(17)TensorFlow高阶操作六: 高阶OP


Outline

  • where
  • scatter_nd
  • meshgrid

1. Where(tensor)

where只有一个参数tensor的时候,会返回这个tensor中所有值为True的坐标。
在这里插入图片描述
在这里插入图片描述

(1) mask=a>0: 将a中大于0的元素标位True,小于0的元素标为False,最后形成一个mask;
(2) tf.boolean_mask(a, mask): 将a按照mask来筛选出大于0的元素;
(3) indices = tf.where(mask): 利用where()函数将mask中值为True的坐标拿出来;
(4) tf.gather_nd(a, indices): 利用gather_nd()函数将a按照indices取出元素;

2. where(cond, A, B)

如果where()中有3个参数,那么就会根据True/False以及A和B来构建新的Tensor;
在这里插入图片描述

(1) tf.where(mask, A, B): 根据mask中True和False的分布来构建新的Tensor,其中,True的值对应A中的元素值1,False的值对应B中的元素值0。

3. 1-D scatter_nd

根据坐标进行有目的性地更新

  • tf.scatter_nd(
    • indices,
    • updates,
    • shape
      在这里插入图片描述

其中,shape为模板,updates里放的是坐标,output是结合了shape和updates的坐标;
在这里插入图片描述

(1) tf.scatter_nd(indices, updates, shape): 将indices和updates对应后放入到shape中,即: 第5个元素放入9,第4个元素放入10,第2个元素放入11,第8个元素放入12。

4. 2-D scatter_nd

在这里插入图片描述
在这里插入图片描述

(1) tf.scatter_nd(indices, updates, shape): 将indices和updates的元素对应后放入到shape中;

5. meshgrid

在3-D图像上利用坐标将函数画出来。
在这里插入图片描述

(1) Points(点的范围)

  • [y, x, 2]
    • [5, 5, 2]
  • [N, 2]
    在这里插入图片描述

如上图所示,一共有25个点,每个点由2个坐标(x和y)来表示,这样其shape=[25, 2];

(2) Numpy

在这里插入图片描述

其中y为[-2, 2]间隔5个点; x为[-2, 2]间隔5个点; 再将这些点保存到points=[]里;
这个方法是利用Numpy来实现的,没有经过GPU加速,而且无法与TensorFlow深度结合在一起。

(3) GPU acceleration

  • x: [-2~2]
  • y: [-2~2]
    → \to
  • Points: [N, 2]
    在这里插入图片描述

y = tf.linspace(-2., 2, 5): 给出y的范围,即[-2., -1., 0., 1., 2.];
x = tf.linspace(-2., 2, 5): 给出x的范围,即[-2., -1., 0., 1., 2.];
points_x, points_y = tf.meshgrid(x, y): 将坐标(x, y)中的x和y拆开分别保存到两个Tensor中去,即points_x和points_y;
在这里插入图片描述

(4) 还原

Points: [N, 2]
在这里插入图片描述

points = tf.stack([points_x, points_y], axis=2): 利用stack()函数进行第3个维度的合并;

(5) 应用

例如,我们需要画出如下函数的曲线或者等高线。
z = s i n ⁡ ( x ) + s i n ⁡ ( y ) z=sin⁡(x)+sin⁡(y) z=sin(x)+sin(y)
在这里插入图片描述

6. meshgrid应用实例

import tensorflow as tf
import matplotlib.pyplot as plt


def func(x):
    """

    :param x: [b, 2]
    :return: z
    """
    z = tf.math.sin(x[..., 0]) + tf.math.sin(x[..., 1])

    return z


x = tf.linspace(0., 2*3.14, 500)
y = tf.linspace(0., 2*3.14, 500)
# [50, 50]
point_x, point_y = tf.meshgrid(x, y)
# [50, 50, 2]
points = tf.stack([point_x, point_y], axis=2)
# points = tf.reshape(points, [-1, -2])
print('points:', points.shape)
z = func(points)
print('z:', z.shape)

plt.figure('plot 2d func value')
plt.imshow(z, origin='lower', interpolation='none')
plt.colorbar()

plt.figure('plot 2d func contour')
plt.contour(point_x, point_y, z)
plt.colorbar()
plt.show()

运行结果如下:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考文献:
[1] 龙良曲:《深度学习与TensorFlow2入门实战》
[2] https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/scatter_nd

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值