N个数选出M个最小或最大值,topk算法

一种遍历一次即可得到TOPK的算法


def get_topk(tensor_1d, topk=3):
    # value in topk_vals are placed by descending order
    topk_vals = [-float("Inf")] * topk
    topk_idxs = [0] * topk

    for idx, elem in enumerate(tensor_1d):
        if elem > topk_vals[topk - 1]:
            for i in range(topk):
                # find where current top value should be placed
                # then we right shift the topk_vals to place the top value
                if elem > topk_vals[i]:
                    # right shift
                    for j in reversed(range(i, topk-1)):
                        topk_vals[j+1] = topk_vals[j]
                        topk_idxs[j+1] = topk_idxs[j]
                    
                    topk_vals[i] = elem
                    topk_idxs[i] = idx
                    break
    return topk_vals, topk_idxs

tensor_1d=[1,2,3,4,4,5,5,6]

topk_vals, topk_idxs = get_topk(tensor_1d, topk=3)

print(topk_vals)
print(topk_idxs)

该方法保存top k元素的数组,然后遍历每个元素,依次向后与该数组元素比较,找到大于当前top元素的位置,然后从当前位置右移top k元素的数组并插入该top元素。

一种基于直方图的方法

有时候需要N个数选出M个最小或最大值算法,但并未要求选出的M个数据需要有序排列,那么这使得算法复杂度可以很低。这里给出一种可行的方法,比常见的一些方法具有更加显著的速度。

1,找出N个数据的最大和最小值。需要一次N个数据遍历。

2,根据最大最小值根据一个间隔创建一个直方图,N个数据遍历一次,进行直方图统计。例如,直方图间隔为k,则每次直方图(CurDat-MinDat)/K位置加1即可.

3,根据直方图从最小或最大处开始,找到累积大于等于M个元素的阈值,需要一次直方图遍历。

4,根据步骤3找到的阈值选出M1个数据,需要一次N个数据遍历。

5,由于M1可能略大于M,稍作后处理移除即可。可能需要对M1进行多余数据次数的排序。

该方法的缺点则是第3步缺乏准确性,导致选出的数据会略大于需求个数,在一些极端情况下可能工作不好。但在我的应用中,未限定死选出的最小数必须为M个,只需要近似即可,因此该方法具有较好实现性和性能。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Luchang-Li

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值