TopK 问题三种方法总结
文章目录
- TopK 问题三种方法总结
- what's topK?
- quickSelect
- 二分法
- 堆
- 经典例题
- leetcode-[719. Find K-th Smallest Pair Distance](https://leetcode.cn/problems/find-k-th-smallest-pair-distance/)
- leetcode-[373. Find K Pairs with Smallest Sums](https://leetcode.cn/problems/find-k-pairs-with-smallest-sums/)
- leetcode-[378. Kth Smallest Element in a Sorted Matrix](https://leetcode.cn/problems/kth-smallest-element-in-a-sorted-matrix/)
references:
快速选择排序 Quick select 解决Top K 问题
what’s topK?
topK问题是实际应用中涉及面较广的一个抽象问题,譬如:从20亿个数字的文本中,找出最大的前100个。
引入
看到这个问题你可能自然而然的想到了排序,无论是平均时间复杂度为 O(NlogN)
的快排,还是时间复杂度为 O(NlogN)
的归并排序和堆排序都可以,但问题是如果 N 很大呢?
有没有一种方法不需要对所有元素进行排序呢?
且看使用冒泡排序或者选择排序,那么时间复杂度就是 O(Nk)
了,和刚刚提到的方法哪个更优呢,这取决于 logN
和k
的大小。
quickSelect
解题步骤
- swap函数交换元素位置
- partition 按照快排的分割思想,pivot左边是比pivot小的所有数,返回pivot所在位置
- 核查pivot-left+1和k大小比较
- 如果大于k那么topK就在pivot左侧的这些数里面
- 如果等于k那么topK就是pivot所在位置的值
- 如果小于k那么寻找pivot右侧的top
k-(pivot-left+1)
即可
tip:
注意基准值所在位置 pivot_idx 不要再进入递归中
具体实现
const swap=(arr,a,b)=>{
let temp=arr[a];
arr[a]=arr[b];
arr[b]=temp;
};
const partition=(arr,k,left,right)=>{
let pivot=left,lessThan=left;
if(left===right) return left;
for(let i=left;i<=right;i++){
if(arr[i]<arr[pivot]){
lessThan++;
swap(arr,lessThan,i);
}
}
swap(arr,lessThan,pivot);
return lessThan;
};
const quickSelect=(arr,k,left,right)=>{
let idx=partition(arr,k,left,right);
// 一定要注意:idx已经被检查过所以不能再将其加入重新进行检查
if(right-idx+1>k){
return quickSelect(arr,k,idx+1,right);
}else if(right-idx+1===k){
return arr[idx];
}else{
return quickSelect(arr,k-(right-idx+1),left,idx-1);
}
};
const topK=(arr,k)=>{
return quickSelect(arr,k,0,arr.length-1);
};
summary
quickSelect
方法 与Quick sort不同的是,Quick select只考虑所寻找的目标所在的那一部分子数组,而非像Quick sort一样分别再对两边进行分 割。正是因为如此,Quick select将平均时间复杂度从O(nlogn)
降到了O(n)
,但与此同时,QuickSelect与QuickSort一样,是一个不稳定的算法;pivot选取直接影响了算法的好坏,worst case下的时间复杂度达到了O(n^2)
- 时间复杂度最大为:O(n^2)
- 时间复杂度最小为:O(n)
- 空间复杂度:O(1)
二分法
关于二分查找法的详解可以参考我的另一篇文章:详解二分查找法
假如 N 个数中最大值是 V m a x V_{max} Vmax 最小值是 V m i n V_{min} Vmin,那么要求的第 k 大的值一定在 [ V m a x V_{max} Vmax , V m i n V_{min} Vmin] ,我们可以利用二分查找对比原数组中 > = m i d >=mid >=mid 的值的个数,利用求 target 右边界的方式求得答案。
解题步骤
- 找到最大值max和最小值min
- 计算 mid 值开始循环查找
- 注意我们的max可以等于min,因为max min均为数组中的值
- 如果比mid大的数目大于等于k,那么mid过小,min变成mid+1,即求 bisect_right
- 为什么要把等于k的情况加入呢,因为等于k的情况下,选取的mid是小于等于目标值的
- 如果比mid大的数目小于k,那么mid过大,max变成mid-1
- 返回 min 值即为我们想要的
具体实现
const topK2=(arr,k)=>{
let low=Number.MAX_SAFE_INTEGER,high=-Number.MAX_SAFE_INTEGER;
for(let i=0;i<arr.length;i++){
if(arr[i]<low){
low=arr[i];
}
if(arr[i]>high){
high=arr[i];
}
}
// count用来计算数组中大于等于 mid 的个数
let cn;
while(low<=high){
let mid=Math.floor((low+high)/2);
cn=count(arr,mid);
// 这里其实是求 target 右边界
if(cn>=k){
// mid is too small
low=mid+1;
}else{
// mid is too big
// 这里可以做层优化,缩减 k 的值,让 count 在 mid 和 high 中寻找大于等于 mid 的数目
high=mid-1;
}
}
return right;
};
// ========= 扩展 =========
/**
* 如何改造成求最小的第k个值呢,使用求左边界的方式
* @param arr
* @param k
* @returns {number}
*/
const topK3=(arr,k)=>{
const count=(arr,target)=>{
let count=0;
for(let i=0;i<arr.length;i++){
if(arr[i]<=target){
count++;
}
}
return count;
};
let low=Number.MAX_SAFE_INTEGER,high=-Number.MAX_SAFE_INTEGER;
for(let i=0;i<arr.length;i++){
if(arr[i]<low){
low=arr[i];
}
if(arr[i]>high){
high=arr[i];
}
}
// count() 用来计算原数组中小于等于 target 的数量
let cn;
while(low<=high){
let mid=Math.floor((low+high)/2);
cn=count(arr,mid);
if(cn>=k){
// mid is too big
high=mid-1;
}else{
// mid is too small
low=mid+1;
}
}
return low;
};
summary
- 时间复杂度:
O(Nlog(max_val-min_val))
- 空间复杂度:O(1)
leetcode-719-找出第 k 小的距离对
首先这个题就是典型的topk问题,但是用构建最大堆的方式并不能通过,因此这边可以考虑用二分法的方式解决,既然采用二分法套用上面模板即可,但是难点在于如何遍历所有找出来小于mid的所有距离对的个数呢?(遍历所有显然不现实)
此处查找count数目用到了双指针:
- 维护一个right 递增
- 查找一个最小的left满足
arr[right]-arr[left]<=mid
- 叠加
right-left
的值即为小于mid的所有距离对个数。(至于为何是right-left
可以通过自己举例验证比如1到4之间有多少距离对:3+2+1
)
const count=(arr,target)=>{
let res=0,left=0;
for(let right=0;right<arr.length;right++){
while(arr[right]-arr[left]>target)left++;
res+=right-left;
}
return res;
};
const smallestDistancePair1=(arr,k)=>{
arr.sort((a,b)=>a-b);
let min=0,max=arr[arr.length-1]-arr[0];
while(min<=max){
let mid=Math.floor((min+max)/2);
let cn=count(arr,mid);
if(cn>=k){
// mid过大,===k的情况时有可能取的mid也是一个大于目标值的数
max=mid-1;
}else{
min=mid+1;
}
}
return min;
};
leetcode-378-有序矩阵中第K小的元素
解题思路可以参考我的题解二分法解决or最小堆解决,也可以参考下文
堆
具体可以参考我的另一篇文章:数据结构javascript描述中对于heap
的总结。
假设我们已经构造了一个存有 k 个元素的小根堆,根节点元素就是第 k 大元素也就是后 k 个元素中最小的元素,假如后面继续遍历,有个元素 a 比 根节点还小,假设依然成立,如果比它大那么假设就不成立了,此时将根节点换成 a 并重新 heapify,如此直到遍历完所有元素,得到真正的后 k 个元素组成的小根堆。
解题步骤
- 首先把数组中的前k个值用来建一个小根堆。
- 之后的其他数拿到之后进行判断,如果遇到比小顶堆的堆顶的值大的,将堆顶元素替换为该元素,并重新 heapify
- 最终的堆会是数组中最大的 K 个数组成的结构,小根堆的顶部又是结构中的最小数,因此把堆顶的值弹出即可得到Top-K。
Tips:
Topk 问题就用小根堆解决,Lowk 问题就用大根堆解决。
具体实现
// 注意如果求的不是 topK 而是 lowK 则是用 大根堆
import Heap from './algorithm/Heap/MinHeap';
const topK=(arr,k)=>{
let h=new Heap();
for(let i=0;i<k;i++){
h.insert(arr[i]);
}
for(let i=k;i<arr.length;i++){
if (arr[i]>h.data[0]){
h.deleting();
h.insert(arr[i]);
}
}
return h.data[0];
};
summary
- 时间复杂度:
O(NlogK)
其中N
为数组全部长度,K
即为要求的K
(因为堆中元素的数目永远是K
) - 空间复杂度:
O(K)
经典例题
要使用哪种方法解决问题需要根据实际题目做出选择:
leetcode-719. Find K-th Smallest Pair Distance
最佳方法:二分法。其他方法也可以,但复杂度过高。
其他技巧:灵活应用双指针解决问题
class Solution:
def smallestDistancePair(self, nums: List[int], k: int) -> int:
"""
优化:count 时因为数组是有序的,除了二分查找还可以使用双指针
时间复杂度:O(n*logD) 其中 D=max(nums)-min(nums)
空间复杂度:O(logn) 主要是排序占据的
:param nums:
:param k:
:return:
"""
nums.sort()
length = len(nums)
def count(d: int):
"""
寻找距离小于等于 d 的个数
:param d:
:return:
"""
res = 0
j = 0
for i in range(length):
while nums[i] - nums[j] > d:
j += 1
res += i - j
return res
def my_bisect_left(l: int, r: int, target: int, f: Callable) -> int:
"""
返回 target 位于原数组最左侧可以插入的位置
:param l:
:param r:
:param target:
:param f:
:return:
"""
while l <= r:
mid = (l + r) // 2
if f(mid) < target:
l = mid + 1
else:
r = mid - 1
return l
max_val = nums[-1] - nums[0]
# Tip: 因为是要求第 k 个,所以要按小于等于 k 来算
return my_bisect_left(0, max_val, k, count)
leetcode-373. Find K Pairs with Smallest Sums
可以使用二分法解决,但题目找前 k 个而不是第 k 个,因此需要特殊处理一下
根据题目规律,利用小根堆解决问题
class Solution:
def kSmallestPairs_bisect(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
"""
首先读懂题意,从 nums1 nums2 中各取一值,取和排序为前 k 个的索引对。
方法一:穷举所有可能的组合,排序取前 k 个即可,时间复杂度:O(mnlog(mn))
方法二:已知最小值,最大值,可以用二分法求 lowk 的方式,
时间复杂度:O((m+n)log(max-min))
空间复杂度:O(k)
:param nums1:
:param nums2:
:param k:
:return:
"""
m, n = len(nums1), len(nums2)
min_val, max_val = nums1[0] + nums2[0], nums1[-1] + nums2[-1]
def count(target: int) -> int:
# [1,7,11]
# [2,4,6] 3
# 3+17 <=10 4
# 3+9 <=6 2
# 7+9 <=7 3
j = n - 1
count = 0
for i in range(m):
while j >= 0 and nums1[i] + nums2[j] > target:
j -= 1
count += j + 1
return count
while min_val <= max_val:
mid = (min_val + max_val) // 2
if count(mid) < k:
min_val = mid + 1
else:
max_val = mid - 1
# min_val 即为所求 lowk
# print(min_val)
# 但由于 min_val 不仅仅可能是 lowk 也可能是 lowk+1 lowk+2 针对这种情景就需要处理 equal 的场景
res, equal = [], []
j = n - 1
for i in range(m):
while j >= 0 and nums1[i] + nums2[j] > min_val:
j -= 1
for x in range(j + 1):
if nums1[i] + nums2[x] == min_val:
equal.append([nums1[i], nums2[x]])
else:
res.append([nums1[i], nums2[x]])
if len(res) < k:
res.extend(equal[0:k - len(res)])
return res
def heapify(self, nums: List[List[int]], length: int, i: int):
"""
从 i 开始构建最小堆
:param i:
:param length:
:param nums:
:return:
"""
l, r = i * 2 + 1, i * 2 + 2
min_idx = i
if l < length and nums[l][2] < nums[min_idx][2]:
min_idx = l
if r < length and nums[r][2] < nums[min_idx][2]:
min_idx = r
if min_idx != i:
nums[i], nums[min_idx] = nums[min_idx], nums[i]
self.heapify(nums, length, min_idx)
def build_min_heap(self, nums: List[List[int]]):
length = len(nums)
# 最后一个非叶子节点所在索引
idx = length // 2 - 1
for i in range(idx, -1, -1):
self.heapify(nums, length, i)
def heap_extract_min(self, nums: List[List[int]]) -> List[int]:
min_val = nums[0]
length = len(nums)
nums[0], nums[length - 1] = nums[length - 1], nums[0]
nums.pop()
self.heapify(nums, length - 1, 0)
return min_val
def heap_decrease(self, nums: List[List[int]], i: int, val: List[int]):
"""
索引 i 处 修改为 val
:param nums:
:param i:
:param val:
:return:
"""
# 根节点
def get_root(idx: int) -> int:
return (idx + 1) // 2 - 1
nums[i] = val
while get_root(i) >= 0 and nums[get_root(i)][2] > nums[i][2]:
nums[get_root(i)], nums[i] = nums[i], nums[get_root(i)]
i = get_root(i)
def heap_push(self, nums: List[List[int]], val: List[int]):
nums.append(val)
length = len(nums)
self.heap_decrease(nums, length - 1, val)
def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
"""
方法三:先选 k 个值构建大根堆,然后遍历所有元素将其与 heap[0] 对比,得到答案,但这种方式需要遍历 mn 个元素,如果 nums1 nums2
数组较大,显然效率较低。但我们可以根据题目的规律构建小根堆来解决问题:
已知最小的是 (0,0), 下一个就是待比较的就是 (0,1) 和 (1,0)
假如下一个是 (0,1) 那么下一个要比较的是 (1,0) (1,1) (0,2) => (1,0) (0,2)
假如下一个是 (1,0) 那么下一个要比较的是 (0,1) (1,1) (2,0) => (0,1) (2,0)
假如下一个是 (1,0) 那么下一个要比较的是 (0,2) (1,1) (2,0)
其实就是上次比较的数,加上新的 (a+1,b) (a,b+1),但是每次都这么增加其实带有重复的情况,比如
(0,0)
(0,1) (1,0)
选 (0,1) 则 (1,0) 再加入 (1,1) (0,2)
选 (1,0) 则 (1,1) (0,2) 再加入 (1,1) (2,0) 此时重复了 (1,1)
如果一开始我们就有 (0,0) (1,0) (2,0)...(k-1,0), 找到最小值 (a,b) 之后每次都添加 (a,b+1) 则变成了
(0,0)
(1,0)... 再加入 (0,1)
选 (0,1) 则 (1,0) (2,0) ... 再加入 (0,2)
选 (1,0) 则 (2,0) (0,2) ... 再加入 (1,1) 满足需求
建堆:O(N)
extract_min: O(logN)
heap_push: O(logN)
时间复杂度:O(klogk)
空间复杂度:O(k)
:param nums1:
:param nums2:
:param k:
:return:
"""
m, n = len(nums1), len(nums2)
nums = [[i, 0, nums1[i] + nums2[0]] for i in range(min(k, m))]
self.build_min_heap(nums)
res = []
while nums and len(res) < k:
[i, j, _] = self.heap_extract_min(nums)
res.append([nums1[i], nums2[j]])
if j + 1 < n:
self.heap_push(nums, [i, j + 1, nums1[i] + nums2[j + 1]])
return res
leetcode-378. Kth Smallest Element in a Sorted Matrix
最佳方法是使用二分法获得 lowk
其他技巧:z 形搜索(主要针对横纵均有序的矩阵,利用第一行最后一个元素为基准进行搜索的算法)
class Solution:
def kthSmallest_0(self, matrix: List[List[int]], k: int) -> int:
"""
首先读懂题意,已知矩阵中每一行和列是增序的,求矩阵中所有元素的第 k 小的元素
方法一:直接对 n*n 个数字进行排序,时间复杂度 n*n*log(n*n) = O(2n^2logn)
方法二:quick select 期望时间复杂度为 O(n^2)
方法三:由于矩阵中元素其实是有序的,可以考虑使用二分法,得到 max_val, min_val,得到 mid 求小于等于 mid 的值的数量
时间复杂度:O(Nlog(max_val-min_val))
空间复杂度:O(1)
:param matrix:
:param k:
:return:
"""
n = len(matrix)
min_val, max_val = matrix[0][0], matrix[n - 1][n - 1]
def count(target: int, n: int) -> int:
"""
寻找小于等于 target 的个数
eg: 在 [[1,5,9],[10,11,13],[12,13,15]] 中寻找小于等于 8 的个数
1 5 9
10 11 13
12 13 15
如果使用 z 字形搜索,l=0,r=n-1,选这个点作为对比点,比它大的值肯定在下一行,比它小的则可以继续在该行左移,其他的所有点都类似,
因此可以利用 z 字形搜索找到要找的答案。
时间复杂度:O(2n) 即 O(n)
空间复杂度:O(1)
<= 12 的也利用 z 字形搜索,有 3+2+1=6
<=14,有 3+3+2=8 此时要缩减一下
<=13,有 3+3+2=8 此时再缩减一下,13+12//2=12
<=12,有 3+2+1=6 left=mid+1,变成 13,得到结果
:param target:
:return:
"""
left, right = 0, n - 1
res = 0
while left < n and right >= 0:
if matrix[left][right] <= target:
res += right + 1
left += 1
else:
right -= 1
return res
while min_val <= max_val:
mid = (min_val + max_val) // 2
# 注意这里是找 bisect_left
if count(mid, n) < k:
min_val = mid + 1
else:
max_val = mid - 1
return min_val
def heapify(self, nums: List[int], length: int, idx: int):
"""
max-heap heapify
:param nums:
:param length:
:param idx:
:return:
"""
max_idx = idx
l = idx * 2 + 1
r = idx * 2 + 2
if l < length and nums[l] > nums[max_idx]:
max_idx = l
if r < length and nums[r] > nums[max_idx]:
max_idx = r
if max_idx != idx:
nums[max_idx], nums[idx] = nums[idx], nums[max_idx]
self.heapify(nums, length, max_idx)
def build_max_heap(self, nums: List[int]):
length = len(nums)
last = length // 2 - 1
for i in range(last, -1, -1):
self.heapify(nums, length, i)
def heap_change_val(self, nums: List[int], idx: int, val: int):
nums[idx] = val
if idx == 0:
self.heapify(nums, len(nums), 0)
def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
"""
方法四:构建大根堆,审查所有元素,最后得到的根即为第 k 小的元素,时间复杂度 O(N^2logk)
:param matrix:
:param k:
:return:
"""
n = len(matrix)
length = min(k, n * n)
nums = []
total = 0
stop_i, stop_j = 0, 0
for i in range(n):
for j in range(n):
if total >= length:
break
nums.append(matrix[i][j])
stop_i, stop_j = i, j
total += 1
# 构建 max-heap 时间复杂度 O(k)
self.build_max_heap(nums)
# print('====>', nums, stop_i, stop_j)
for j in range(stop_j+1, n):
if nums[0] > matrix[stop_i][j]:
self.heap_change_val(nums, 0, matrix[stop_i][j])
for i in range(stop_i + 1, n):
for j in range(n):
if nums[0] > matrix[i][j]:
self.heap_change_val(nums, 0, matrix[i][j])
return nums[0]