tf.boolean_mask 的作用是 通过布尔值 过滤元素
def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
"""Apply boolean mask to tensor."""
参数解释:
tensor:被过滤的元素列表或数组
mask:一堆 bool 值,它的维度不一定等于 tensor
return: mask 为 true 对应的 tensor 的元素
当 tensor 与 mask 维度一致时,return 一维
# 1维的示例
tensor = [0, 1, 2, 3]
mask = np.array([True, False, True, False])
out = tf.boolean_mask(tensor, mask)
with tf.Session() as sess:
print(sess.run(out)) # [0, 2]
print(out.shape) # (?,)
再看看 mask 与 tensor 维度不同的例子
tensor = [[1, 2], [3, 4], [5, 6]]
mask = np.array([True, False, True]) # mask 与 tensor 维度不同
out2 = tf.boolean_mask(tensor, mask)
with tf.Session() as sess:
print(sess.run(out2)) # [[1, 2], [5, 6]]
print(out2.shape) # (?, 2)
mask 可以用一个函数代替
# 3-D
tensor = tf.constant([
[[2,4],[4,1]],
[[6,8],[2,1]]],tf.float32)
mask = tensor > 2 # 滤波器 mask 与 tensor 相同维度
out3 = tf.boolean_mask(tensor, mask)
with tf.Session() as sess:
print(sess.run(tensor))
print(sess.run(mask)) # [[[False True] [ True False]]
# [[ True True] [False False]]]
print(sess.run(out3)) # [4. 4. 6. 8.] 输出一维
print(out3.shape) # (?,)