import tensorflow as tf
x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])
#判断x里面的元素是否是1
condition_mask=tf.not_equal(x,tf.constant(-1.))
#[ True, False, True, True, False, True]
#将张量拆成两个,按照condition_mask的对应位置
partitioned_data = tf.dynamic_partition(
x, tf.cast(condition_mask, tf.int32) , 2)
#partitioned_data[0]=[-1., -1.]
#partitioned_data[1]=[2.1, 7.2, 6.3, 9.4]
partitioned_data[1] = partitioned_data[1] + 1.0
#这行代码是提取索引位置
condition_indices = tf.dynamic_partition(
tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)
x = tf.dynamic_stitch(condition_indices, partitioned_data)
# Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain
# unchanged.
tf.dynamic_stitch 和 tf.dynamic_partition
最新推荐文章于 2022-09-13 15:31:43 发布