深度学习(17)TensorFlow高阶操作六: 高阶OP
Outline
- where
- scatter_nd
- meshgrid
1. Where(tensor)
where只有一个参数tensor的时候,会返回这个tensor中所有值为True的坐标。
(1) mask=a>0
: 将a中大于0的元素标位True,小于0的元素标为False,最后形成一个mask;
(2) tf.boolean_mask(a, mask)
: 将a按照mask来筛选出大于0的元素;
(3) indices = tf.where(mask)
: 利用where()函数将mask中值为True的坐标拿出来;
(4) tf.gather_nd(a, indices)
: 利用gather_nd()函数将a按照indices取出元素;
2. where(cond, A, B)
如果where()中有3个参数,那么就会根据True/False以及A和B来构建新的Tensor;
(1) tf.where(mask, A, B)
: 根据mask中True和False的分布来构建新的Tensor,其中,True的值对应A中的元素值1,False的值对应B中的元素值0。
3. 1-D scatter_nd
根据坐标进行有目的性地更新
- tf.scatter_nd(
- indices,
- updates,
- shape
其中,shape为模板,updates里放的是坐标,output是结合了shape和updates的坐标;
(1) tf.scatter_nd(indices, updates, shape)
: 将indices和updates对应后放入到shape中,即: 第5个元素放入9,第4个元素放入10,第2个元素放入11,第8个元素放入12。
4. 2-D scatter_nd
(1) tf.scatter_nd(indices, updates, shape)
: 将indices和updates的元素对应后放入到shape中;
5. meshgrid
在3-D图像上利用坐标将函数画出来。
(1) Points(点的范围)
- [y, x, 2]
- [5, 5, 2]
- [N, 2]
如上图所示,一共有25个点,每个点由2个坐标(x和y)来表示,这样其shape=[25, 2];
(2) Numpy
其中y为[-2, 2]间隔5个点; x为[-2, 2]间隔5个点; 再将这些点保存到points=[]里;
这个方法是利用Numpy来实现的,没有经过GPU加速,而且无法与TensorFlow深度结合在一起。
(3) GPU acceleration
- x: [-2~2]
- y: [-2~2]
→ \to → - Points: [N, 2]
y = tf.linspace(-2., 2, 5)
: 给出y的范围,即[-2., -1., 0., 1., 2.];
x = tf.linspace(-2., 2, 5)
: 给出x的范围,即[-2., -1., 0., 1., 2.];
points_x, points_y = tf.meshgrid(x, y)
: 将坐标(x, y)中的x和y拆开分别保存到两个Tensor中去,即points_x和points_y;
(4) 还原
Points: [N, 2]
points = tf.stack([points_x, points_y], axis=2)
: 利用stack()函数进行第3个维度的合并;
(5) 应用
例如,我们需要画出如下函数的曲线或者等高线。
z
=
s
i
n
(
x
)
+
s
i
n
(
y
)
z=sin(x)+sin(y)
z=sin(x)+sin(y)
6. meshgrid应用实例
import tensorflow as tf
import matplotlib.pyplot as plt
def func(x):
"""
:param x: [b, 2]
:return: z
"""
z = tf.math.sin(x[..., 0]) + tf.math.sin(x[..., 1])
return z
x = tf.linspace(0., 2*3.14, 500)
y = tf.linspace(0., 2*3.14, 500)
# [50, 50]
point_x, point_y = tf.meshgrid(x, y)
# [50, 50, 2]
points = tf.stack([point_x, point_y], axis=2)
# points = tf.reshape(points, [-1, -2])
print('points:', points.shape)
z = func(points)
print('z:', z.shape)
plt.figure('plot 2d func value')
plt.imshow(z, origin='lower', interpolation='none')
plt.colorbar()
plt.figure('plot 2d func contour')
plt.contour(point_x, point_y, z)
plt.colorbar()
plt.show()
运行结果如下:
参考文献:
[1] 龙良曲:《深度学习与TensorFlow2入门实战》
[2] https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/scatter_nd