tensorflow tf.nn.top_k 生成mask 提取值

# 通过生成boolean tensor的办法:
a = tf.convert_to_tensor([[40, 30, 20, 10], [10, 20, 30, 40]])
b = tf.nn.top_k(a, 2)

print(sess.run(b))
TopKV2(values=array([[40, 30],
   [40, 30]], dtype=int32), indices=array([[0, 1],
   [3, 2]], dtype=int32))

print(sess.run(b).values))
array([[40, 30],
       [40, 30]], dtype=int32)

kth = tf.reduce_min(b.values,1,keepdims=True) # 找出最小值
top2 = tf.greater_equal(a, kth) # 大于最小值的为true
print(sess.run(top2))
array([[ True,  True, False, False],
       [False, False,  True,  True]], dtype=bool)
# 通过生成id后scatter的办法:
import tensorflow as tf

# Input data
a = tf.placeholder(tf.float32, [None, None])
num_top = tf.placeholder(tf.int32, [])
# Find top elements
a_top, a_top_idx = tf.nn.top_k(a, num_top, sorted=False)
# Apply softmax
a_top_sm = tf.nn.softmax(a_top)
# Reconstruct into original shape
a_shape = tf.shape(a)
a_row_idx = tf.tile(tf.range(a_shape[0])[:, tf.newaxis], (1, num_top))
scatter_idx = tf.stack([a_row_idx, a_top_idx], axis=-1)  # 生成scatter_index
result = tf.scatter_nd(scatter_idx, a_top_sm, a_shape)  #生成矩阵
# Test
with tf.Session() as sess:
    result_val = sess.run(result, feed_dict={a: [[2, 5, 4, 7], [7, 5, 6, 8]], num_top: 2})
    print(result_val)

[[0.         0.11920291 0.         0.880797  ]
 [0.26894143 0.         0.         0.7310586 ]]
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值