函数原型
tf.boolean_mask(
tensor, mask, axis=None, name='boolean_mask'
)
函数说明
用于保留张量tenso的一部分,具体取决于张量mask的下标。
参数tensor表示待操作的张量,参数mask是一个布尔类型的矩阵,决定tensor中那些地方应该保留,参数axis表示从哪一个轴开始。
函数使用
1、一维张量
>>> tensor = tf.constant([1, 2, 3, 4])
>>> mask = tf.constant([True, True, False, False])
>>> tf.boolean_mask(tensor, mask)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2])>
2、二维张量,mask如果为一维张量,则需要与axis对应的轴的长度相等,如果为二维张量,则需要与tensor的shape相同。
>>> tensor = tf.constant([[1, 2], [3, 4]])
>>> mask = tf.constant([True, False])
# 默认从axis=0开始
>>> tf.boolean_mask(tensor, mask)
<tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[1, 2]])>
>>> tensor = tf.constant([[1, 2], [3, 4]])
>>> mask = tf.constant([True, False])
>>> tf.boolean_mask(tensor, mask, axis=1)
<tf.Tensor: shape=(2, 1), dtype=int32, numpy=
array([[1],
[3]])>
>>> tensor = tf.constant([[1, 2], [3, 4]])
>>> mask = tf.constant([[True, True], [True, False]])
>>> tf.boolean_mask(tensor, mask)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3])>