TopK问题三种方法总结

TopK 问题三种方法总结


references:

Top K问题的两种解决思路

TOP-K问题的几种解法

Top-K问题的几种求解方法

快速选择排序 Quick select 解决Top K 问题

what’s topK?

topK问题是实际应用中涉及面较广的一个抽象问题,譬如:从20亿个数字的文本中,找出最大的前100个。

引入

看到这个问题你可能自然而然的想到了排序,无论是平均时间复杂度为 O(NlogN)的快排,还是时间复杂度为 O(NlogN)的归并排序和堆排序都可以,但问题是如果 N 很大呢?

有没有一种方法不需要对所有元素进行排序呢?

且看使用冒泡排序或者选择排序,那么时间复杂度就是 O(Nk)了,和刚刚提到的方法哪个更优呢,这取决于 logNk的大小。

quickSelect

解题步骤

  • swap函数交换元素位置
  • partition 按照快排的分割思想,pivot左边是比pivot小的所有数,返回pivot所在位置
  • 核查pivot-left+1和k大小比较
    • 如果大于k那么topK就在pivot左侧的这些数里面
    • 如果等于k那么topK就是pivot所在位置的值
    • 如果小于k那么寻找pivot右侧的topk-(pivot-left+1)即可

tip:

注意基准值所在位置 pivot_idx 不要再进入递归中

具体实现

const swap=(arr,a,b)=>{
    let temp=arr[a];
    arr[a]=arr[b];
    arr[b]=temp;
};
const partition=(arr,k,left,right)=>{
    let pivot=left,lessThan=left;
    if(left===right) return left;
    for(let i=left;i<=right;i++){
        if(arr[i]<arr[pivot]){
            lessThan++;
            swap(arr,lessThan,i);
        }
    }
    swap(arr,lessThan,pivot);
    return lessThan;
};
const quickSelect=(arr,k,left,right)=>{
    let idx=partition(arr,k,left,right);
    // 一定要注意:idx已经被检查过所以不能再将其加入重新进行检查
    if(right-idx+1>k){
        return quickSelect(arr,k,idx+1,right);
    }else if(right-idx+1===k){
        return arr[idx];
    }else{
        return quickSelect(arr,k-(right-idx+1),left,idx-1);
    }
};
const topK=(arr,k)=>{
    return quickSelect(arr,k,0,arr.length-1);
};


summary

quickSelect方法 与Quick sort不同的是,Quick select只考虑所寻找的目标所在的那一部分子数组,而非像Quick sort一样分别再对两边进行分 割。正是因为如此,Quick select将平均时间复杂度从O(nlogn)降到了O(n),但与此同时,QuickSelect与QuickSort一样,是一个不稳定的算法;pivot选取直接影响了算法的好坏,worst case下的时间复杂度达到了O(n^2)

  • 时间复杂度最大为:O(n^2)
  • 时间复杂度最小为:O(n)
  • 空间复杂度:O(1)

二分法

关于二分查找法的详解可以参考我的另一篇文章:详解二分查找法

假如 N 个数中最大值是 V m a x V_{max} Vmax 最小值是 V m i n V_{min} Vmin,那么要求的第 k 大的值一定在 [ V m a x V_{max} Vmax , V m i n V_{min} Vmin] ,我们可以利用二分查找对比原数组中 > = m i d >=mid >=mid 的值的个数,利用求 target 右边界的方式求得答案。

解题步骤

  • 找到最大值max和最小值min
  • 计算 mid 值开始循环查找
    • 注意我们的max可以等于min,因为max min均为数组中的值
    • 如果比mid大的数目大于等于k,那么mid过小,min变成mid+1,即求 bisect_right
      • 为什么要把等于k的情况加入呢,因为等于k的情况下,选取的mid是小于等于目标值的
    • 如果比mid大的数目小于k,那么mid过大,max变成mid-1
  • 返回 min 值即为我们想要的

具体实现

const topK2=(arr,k)=>{
    let low=Number.MAX_SAFE_INTEGER,high=-Number.MAX_SAFE_INTEGER;
    for(let i=0;i<arr.length;i++){
        if(arr[i]<low){
            low=arr[i];
        }
        if(arr[i]>high){
            high=arr[i];
        }
    }
    // count用来计算数组中大于等于 mid 的个数
    let cn;
    while(low<=high){
        let mid=Math.floor((low+high)/2);
        cn=count(arr,mid);
        // 这里其实是求 target 右边界
        if(cn>=k){
            // mid is too small
            low=mid+1;
        }else{
            // mid is too big
            // 这里可以做层优化,缩减 k 的值,让 count 在 mid 和 high 中寻找大于等于 mid 的数目
            high=mid-1;
        }
    }
    return right;
};




// ========= 扩展 =========
/**
 * 如何改造成求最小的第k个值呢,使用求左边界的方式
 * @param arr
 * @param k
 * @returns {number}
 */
const topK3=(arr,k)=>{
    const count=(arr,target)=>{
        let count=0;
        for(let i=0;i<arr.length;i++){
            if(arr[i]<=target){
                count++;
            }
        }
        return count;
    };
    let low=Number.MAX_SAFE_INTEGER,high=-Number.MAX_SAFE_INTEGER;
    for(let i=0;i<arr.length;i++){
        if(arr[i]<low){
            low=arr[i];
        }
        if(arr[i]>high){
            high=arr[i];
        }
    }
    // count() 用来计算原数组中小于等于 target 的数量
    let cn;
    while(low<=high){
        let mid=Math.floor((low+high)/2);
        cn=count(arr,mid);
        if(cn>=k){
            // mid is too big
            high=mid-1;
        }else{
            // mid is too small
            low=mid+1;
        }
    }
    return low;
};

summary

  • 时间复杂度:O(Nlog(max_val-min_val))
  • 空间复杂度:O(1)

leetcode-719-找出第 k 小的距离对

首先这个题就是典型的topk问题,但是用构建最大堆的方式并不能通过,因此这边可以考虑用二分法的方式解决,既然采用二分法套用上面模板即可,但是难点在于如何遍历所有找出来小于mid的所有距离对的个数呢?(遍历所有显然不现实)

此处查找count数目用到了双指针:

  • 维护一个right 递增
  • 查找一个最小的left满足arr[right]-arr[left]<=mid
  • 叠加right-left的值即为小于mid的所有距离对个数。(至于为何是right-left可以通过自己举例验证比如1到4之间有多少距离对:3+2+1)
const count=(arr,target)=>{
    let res=0,left=0;
    for(let right=0;right<arr.length;right++){
        while(arr[right]-arr[left]>target)left++;
        res+=right-left;
    }
    return res;
};
const smallestDistancePair1=(arr,k)=>{
    arr.sort((a,b)=>a-b);
    let min=0,max=arr[arr.length-1]-arr[0];
    while(min<=max){
        let mid=Math.floor((min+max)/2);
        let cn=count(arr,mid);
        if(cn>=k){
            // mid过大,===k的情况时有可能取的mid也是一个大于目标值的数
            max=mid-1;
        }else{
            min=mid+1;
        }
    }
    return min;
};

leetcode-378-有序矩阵中第K小的元素

解题思路可以参考我的题解二分法解决or最小堆解决,也可以参考下文

具体可以参考我的另一篇文章:数据结构javascript描述中对于heap的总结。

假设我们已经构造了一个存有 k 个元素的小根堆,根节点元素就是第 k 大元素也就是后 k 个元素中最小的元素,假如后面继续遍历,有个元素 a 比 根节点还小,假设依然成立,如果比它大那么假设就不成立了,此时将根节点换成 a 并重新 heapify,如此直到遍历完所有元素,得到真正的后 k 个元素组成的小根堆。

解题步骤

  • 首先把数组中的前k个值用来建一个小根堆
  • 之后的其他数拿到之后进行判断,如果遇到比小顶堆的堆顶的值大的,将堆顶元素替换为该元素,并重新 heapify
  • 最终的堆会是数组中最大的 K 个数组成的结构,小根堆的顶部又是结构中的最小数,因此把堆顶的值弹出即可得到Top-K。

Tips:

Topk 问题就用小根堆解决,Lowk 问题就用大根堆解决。

具体实现

// 注意如果求的不是 topK 而是 lowK 则是用 大根堆
import Heap from './algorithm/Heap/MinHeap';
const topK=(arr,k)=>{
    let h=new Heap();
    for(let i=0;i<k;i++){
        h.insert(arr[i]);
    }
    for(let i=k;i<arr.length;i++){
        if (arr[i]>h.data[0]){
            h.deleting();
            h.insert(arr[i]);
        }
    }
    return h.data[0];
};

summary

  • 时间复杂度: O(NlogK)其中N为数组全部长度,K即为要求的K(因为堆中元素的数目永远是K
  • 空间复杂度:O(K)

经典例题

要使用哪种方法解决问题需要根据实际题目做出选择:

leetcode-719. Find K-th Smallest Pair Distance

最佳方法:二分法。其他方法也可以,但复杂度过高。

其他技巧:灵活应用双指针解决问题

class Solution:
      def smallestDistancePair(self, nums: List[int], k: int) -> int:
        """
        优化:count 时因为数组是有序的,除了二分查找还可以使用双指针
        时间复杂度:O(n*logD) 其中 D=max(nums)-min(nums)
        空间复杂度:O(logn) 主要是排序占据的
        :param nums:
        :param k:
        :return:
        """
        nums.sort()
        length = len(nums)

        def count(d: int):
            """
            寻找距离小于等于 d 的个数
            :param d:
            :return:
            """
            res = 0
            j = 0
            for i in range(length):
                while nums[i] - nums[j] > d:
                    j += 1
                res += i - j
            return res

        def my_bisect_left(l: int, r: int, target: int, f: Callable) -> int:
            """
            返回 target 位于原数组最左侧可以插入的位置
            :param l:
            :param r:
            :param target:
            :param f:
            :return:
            """
            while l <= r:
                mid = (l + r) // 2
                if f(mid) < target:
                    l = mid + 1
                else:
                    r = mid - 1
            return l

        max_val = nums[-1] - nums[0]
        # Tip: 因为是要求第 k 个,所以要按小于等于 k 来算
        return my_bisect_left(0, max_val, k, count)

leetcode-373. Find K Pairs with Smallest Sums

可以使用二分法解决,但题目找前 k 个而不是第 k 个,因此需要特殊处理一下

根据题目规律,利用小根堆解决问题

class Solution:
    def kSmallestPairs_bisect(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
        """
        首先读懂题意,从 nums1 nums2 中各取一值,取和排序为前 k 个的索引对。
        方法一:穷举所有可能的组合,排序取前 k 个即可,时间复杂度:O(mnlog(mn))
        方法二:已知最小值,最大值,可以用二分法求 lowk 的方式,
        时间复杂度:O((m+n)log(max-min))
        空间复杂度:O(k)
        :param nums1:
        :param nums2:
        :param k:
        :return:
        """
        m, n = len(nums1), len(nums2)
        min_val, max_val = nums1[0] + nums2[0], nums1[-1] + nums2[-1]

        def count(target: int) -> int:
            # [1,7,11]
            # [2,4,6] 3
            # 3+17  <=10 4
            # 3+9  <=6 2
            # 7+9 <=7 3
            j = n - 1
            count = 0
            for i in range(m):
                while j >= 0 and nums1[i] + nums2[j] > target:
                    j -= 1
                count += j + 1
            return count

        while min_val <= max_val:
            mid = (min_val + max_val) // 2
            if count(mid) < k:
                min_val = mid + 1
            else:
                max_val = mid - 1
        # min_val 即为所求 lowk
        # print(min_val)
        # 但由于 min_val 不仅仅可能是 lowk 也可能是 lowk+1 lowk+2 针对这种情景就需要处理 equal 的场景
        res, equal = [], []
        j = n - 1
        for i in range(m):
            while j >= 0 and nums1[i] + nums2[j] > min_val:
                j -= 1
            for x in range(j + 1):
                if nums1[i] + nums2[x] == min_val:
                    equal.append([nums1[i], nums2[x]])
                else:
                    res.append([nums1[i], nums2[x]])
        if len(res) < k:
            res.extend(equal[0:k - len(res)])
        return res

    def heapify(self, nums: List[List[int]], length: int, i: int):
        """
        从 i 开始构建最小堆
        :param i:
        :param length:
        :param nums:
        :return:
        """
        l, r = i * 2 + 1, i * 2 + 2
        min_idx = i
        if l < length and nums[l][2] < nums[min_idx][2]:
            min_idx = l
        if r < length and nums[r][2] < nums[min_idx][2]:
            min_idx = r
        if min_idx != i:
            nums[i], nums[min_idx] = nums[min_idx], nums[i]
            self.heapify(nums, length, min_idx)

    def build_min_heap(self, nums: List[List[int]]):
        length = len(nums)
        # 最后一个非叶子节点所在索引
        idx = length // 2 - 1
        for i in range(idx, -1, -1):
            self.heapify(nums, length, i)

    def heap_extract_min(self, nums: List[List[int]]) -> List[int]:
        min_val = nums[0]
        length = len(nums)
        nums[0], nums[length - 1] = nums[length - 1], nums[0]
        nums.pop()
        self.heapify(nums, length - 1, 0)
        return min_val

    def heap_decrease(self, nums: List[List[int]], i: int, val: List[int]):
        """
        索引 i 处 修改为 val
        :param nums:
        :param i:
        :param val:
        :return:
        """

        # 根节点
        def get_root(idx: int) -> int:
            return (idx + 1) // 2 - 1

        nums[i] = val
        while get_root(i) >= 0 and nums[get_root(i)][2] > nums[i][2]:
            nums[get_root(i)], nums[i] = nums[i], nums[get_root(i)]
            i = get_root(i)

    def heap_push(self, nums: List[List[int]], val: List[int]):
        nums.append(val)
        length = len(nums)
        self.heap_decrease(nums, length - 1, val)

    def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
        """
        方法三:先选 k 个值构建大根堆,然后遍历所有元素将其与 heap[0] 对比,得到答案,但这种方式需要遍历 mn 个元素,如果 nums1 nums2
        数组较大,显然效率较低。但我们可以根据题目的规律构建小根堆来解决问题:
        已知最小的是 (0,0), 下一个就是待比较的就是 (0,1) 和 (1,0)
        假如下一个是 (0,1) 那么下一个要比较的是 (1,0) (1,1) (0,2) => (1,0) (0,2)
        假如下一个是 (1,0) 那么下一个要比较的是 (0,1) (1,1) (2,0) => (0,1) (2,0)

        假如下一个是 (1,0) 那么下一个要比较的是 (0,2) (1,1) (2,0)
        其实就是上次比较的数,加上新的 (a+1,b) (a,b+1),但是每次都这么增加其实带有重复的情况,比如
        (0,0)
        (0,1) (1,0)
        选 (0,1) 则 (1,0) 再加入 (1,1) (0,2)
        选 (1,0) 则 (1,1) (0,2) 再加入 (1,1) (2,0) 此时重复了 (1,1)
        如果一开始我们就有 (0,0) (1,0) (2,0)...(k-1,0), 找到最小值 (a,b) 之后每次都添加 (a,b+1) 则变成了
        (0,0)
        (1,0)... 再加入 (0,1)
        选 (0,1) 则 (1,0) (2,0) ... 再加入 (0,2)
        选 (1,0) 则 (2,0) (0,2) ... 再加入 (1,1) 满足需求
        建堆:O(N)
        extract_min: O(logN)
        heap_push: O(logN)
        时间复杂度:O(klogk)
        空间复杂度:O(k)
        :param nums1:
        :param nums2:
        :param k:
        :return:
        """
        m, n = len(nums1), len(nums2)
        nums = [[i, 0, nums1[i] + nums2[0]] for i in range(min(k, m))]
        self.build_min_heap(nums)
        res = []
        while nums and len(res) < k:
            [i, j, _] = self.heap_extract_min(nums)
            res.append([nums1[i], nums2[j]])
            if j + 1 < n:
                self.heap_push(nums, [i, j + 1, nums1[i] + nums2[j + 1]])
        return res

leetcode-378. Kth Smallest Element in a Sorted Matrix

最佳方法是使用二分法获得 lowk

其他技巧:z 形搜索(主要针对横纵均有序的矩阵,利用第一行最后一个元素为基准进行搜索的算法)

class Solution:
    def kthSmallest_0(self, matrix: List[List[int]], k: int) -> int:
        """
        首先读懂题意,已知矩阵中每一行和列是增序的,求矩阵中所有元素的第 k 小的元素
        方法一:直接对 n*n 个数字进行排序,时间复杂度 n*n*log(n*n) = O(2n^2logn)
        方法二:quick select 期望时间复杂度为 O(n^2)
        方法三:由于矩阵中元素其实是有序的,可以考虑使用二分法,得到 max_val, min_val,得到 mid 求小于等于 mid 的值的数量
        时间复杂度:O(Nlog(max_val-min_val))
        空间复杂度:O(1)
        :param matrix:
        :param k:
        :return:
        """
        n = len(matrix)
        min_val, max_val = matrix[0][0], matrix[n - 1][n - 1]

        def count(target: int, n: int) -> int:
            """
            寻找小于等于 target 的个数
            eg: 在 [[1,5,9],[10,11,13],[12,13,15]] 中寻找小于等于 8 的个数
            1 5 9
            10 11 13
            12 13 15
            如果使用 z 字形搜索,l=0,r=n-1,选这个点作为对比点,比它大的值肯定在下一行,比它小的则可以继续在该行左移,其他的所有点都类似,
            因此可以利用 z 字形搜索找到要找的答案。
            时间复杂度:O(2n) 即 O(n)
            空间复杂度:O(1)
            <= 12 的也利用 z 字形搜索,有 3+2+1=6
            <=14,有 3+3+2=8 此时要缩减一下
            <=13,有 3+3+2=8 此时再缩减一下,13+12//2=12
            <=12,有 3+2+1=6 left=mid+1,变成 13,得到结果
            :param target:
            :return:
            """
            left, right = 0, n - 1
            res = 0
            while left < n and right >= 0:
                if matrix[left][right] <= target:
                    res += right + 1
                    left += 1
                else:
                    right -= 1
            return res

        while min_val <= max_val:
            mid = (min_val + max_val) // 2
            # 注意这里是找 bisect_left
            if count(mid, n) < k:
                min_val = mid + 1
            else:
                max_val = mid - 1
        return min_val

    def heapify(self, nums: List[int], length: int, idx: int):
        """
        max-heap heapify
        :param nums:
        :param length:
        :param idx:
        :return:
        """
        max_idx = idx
        l = idx * 2 + 1
        r = idx * 2 + 2
        if l < length and nums[l] > nums[max_idx]:
            max_idx = l
        if r < length and nums[r] > nums[max_idx]:
            max_idx = r
        if max_idx != idx:
            nums[max_idx], nums[idx] = nums[idx], nums[max_idx]
            self.heapify(nums, length, max_idx)

    def build_max_heap(self, nums: List[int]):
        length = len(nums)
        last = length // 2 - 1
        for i in range(last, -1, -1):
            self.heapify(nums, length, i)

    def heap_change_val(self, nums: List[int], idx: int, val: int):
        nums[idx] = val
        if idx == 0:
            self.heapify(nums, len(nums), 0)

    def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
        """
        方法四:构建大根堆,审查所有元素,最后得到的根即为第 k 小的元素,时间复杂度 O(N^2logk)
        :param matrix:
        :param k:
        :return:
        """
        n = len(matrix)
        length = min(k, n * n)
        nums = []
        total = 0
        stop_i, stop_j = 0, 0
        for i in range(n):
            for j in range(n):
                if total >= length:
                    break
                nums.append(matrix[i][j])
                stop_i, stop_j = i, j
                total += 1
        # 构建 max-heap 时间复杂度 O(k)
        self.build_max_heap(nums)
        # print('====>', nums, stop_i, stop_j)
        for j in range(stop_j+1, n):
            if nums[0] > matrix[stop_i][j]:
                self.heap_change_val(nums, 0, matrix[stop_i][j])
        for i in range(stop_i + 1, n):
            for j in range(n):
                if nums[0] > matrix[i][j]:
                    self.heap_change_val(nums, 0, matrix[i][j])
        return nums[0]

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值