tf.dynamic_stitch 和  tf.dynamic_partition

import  tensorflow as tf
x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])

#判断x里面的元素是否是1
condition_mask=tf.not_equal(x,tf.constant(-1.))

#[ True, False,  True,  True, False,  True]

#将张量拆成两个,按照condition_mask的对应位置
partitioned_data = tf.dynamic_partition(
    x, tf.cast(condition_mask, tf.int32) , 2)


#partitioned_data[0]=[-1., -1.]
#partitioned_data[1]=[2.1, 7.2, 6.3, 9.4]


partitioned_data[1] = partitioned_data[1] + 1.0



#这行代码是提取索引位置
condition_indices = tf.dynamic_partition(
    tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)


x = tf.dynamic_stitch(condition_indices, partitioned_data)
# Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain
# unchanged.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值