import tensorflow as tf
# Build a float mask that is 1.0 where `a` is negative and 0.0 elsewhere.
a = tf.constant([[-1, 2, 3],
                 [2, -3, 4],
                 [5, 6, -7]], dtype=tf.float32)
# Derive the fill tensors from `a` so their shape and dtype always match it;
# TF1's tf.where requires x and y to have the same shape as the condition,
# so a hard-coded shape would break as soon as `a` changes size.
d = tf.ones_like(a, name='d')
e = tf.zeros_like(a, name='e')
# Take 1.0 (from `d`) where the element is negative, else 0.0 (from `e`).
# Equivalent single-op form: tf.cast(a < 0, tf.float32).
b = tf.where(a < 0, d, e)

# TF1 graph mode: evaluate `b` in a session. The graph contains no
# tf.Variable, so no variable initializer needs to be run first.
with tf.Session() as sess:
    print(sess.run(b))
# Expected output:
# [[1. 0. 0.]
#  [0. 1. 0.]
#  [0. 0. 1.]]
# TensorFlow: generating a mask for a tensor (original title: "tensorflow 生成一个tensor的mask")
# Source page note: latest recommended article published 2022-08-02 15:46:51 (最新推荐文章于 2022-08-02 15:46:51 发布)