在numpy中使用O(1)的np.argpartition()
方法
np.argpartition
官网文档:https://numpy.org/doc/stable/reference/generated/numpy.argpartition.html
对于一个array
,使用:
np.argpartition(array, kth=8)
其中kth
代表第几大的数据,它通过将整个数组切分成两个部分,左边是比第k大的数小的部分,右边是比第k大的数据大的部分
反过来同理,下面的函数加一个负号就代表这是第8小的数据:
np.argpartition(array, kth=-8)
示例代码
import numpy as np
def get_k_max(array, k):
_k_sort = np.argpartition(array, -k)[-k:] # 最大的k个数据的下标
return array[_k_sort]
def get_k_min(array, k):
_k_sort = np.argpartition(array, k)[:k] # 最小的k个数据的下标
return array[_k_sort]
if __name__ == '__main__':
a = np.array([10, 2, 3, 40, 5, 6, 7, 8, 9])
print("得到TOP3的数据:", get_k_max(a, 3))
print("得到3个最小的数据:", get_k_min(a, 3))