python(np.argpartition())输出 前n个最大/小值所对应的索引

利用python numpy标准函数np.argpartition(),输出 前n个最大/小值所对应的索引(输出的索引是无序的)

需要注意:

1、输出索引所对应的数值是乱序的。这种方法复杂度要低

import numpy as np


def topk_partition(matrix, k, top_max=True):
    '''
    输出matrix矩阵的最大/小前k个值的index,无序
    所对应的top k的值:matrix[index[i][0], index[i][1]]
                        matrix[index[i]]
    :param matrix: 二维或一维
    :param k:
    :param top_max: True:最大值的前k个值,无序;False:最小值的前k个值,无序;
    :return:
    '''
    flag_2d = False
    if len(matrix.shape) == 2:
        flag_2d = True
        matrix1 = matrix.reshape((matrix.shape[0] * matrix.shape[1]))
    else:
        matrix1 = matrix.copy()
    if top_max: # matrix前k个最大值 所对应的index
        index = np.argpartition(matrix1, -k)[-k:len(matrix1)]
    else:   # matrix前k个最小值 所对应的index
        index = np.argpartition(matrix1, k-1)[0:k]

    # 解析index
    if flag_2d:
        index_2d = []
        for i in index:
            row = i // matrix.shape[1]
            col = i % matrix.shape[1]
            index_2d.append([row,col])
        index = index_2d

    return index

if __name__ == '__main__':
    # dists= np.array([ 6,0,2,1,5,1.2,4,22])
    dists = np.random.randint(0, 10, size=(3, 4))
    print(dists)
    index = topk_partition(dists, 4, top_max=False)
    print(index)
    # print(dists[index])

 

  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值