杨氏矩阵找第N大(小)的O(N)线性算法 LeetCode 378. Kth Smallest Element in a Sorted Matrix 373. Find K Pairs

265 篇文章 1 订阅

杨氏矩阵:一个N*N的矩阵,它的每行每列都单调递增(或者宽松一些,单调不减),即a[i][j]<=a[i+1][j], a[i][j]<=a[i][j+1]。

遇到的两道面试题:
1. 输出杨氏矩阵中最小的N个数。
2. 两个升序数组A和B,长度都是N。从两个数组中分别取出一个数,相加得到一个和。求这N*N个和的前N小。

本质上第2题可以转化成第一题:把A[0]+B[k]的结果填入矩阵第一行,A[1]+B[k]的结果填入第二行……就得到一个杨氏矩阵。所以现在就只考虑第1题咯。

此题常见的一种做法:N路归并,用一个大小为N的堆,可以O(NlgN)得到解。但是利用杨氏矩阵的性质,这题是有O(N)的算法的……

(为了方便,把矩阵记为num)

首先根据杨氏矩阵的性质得到最关键的一点:前N小的数,肯定不大于矩阵中的num[sqrt(N)+1][sqrt(N)+1]。

(为了方便,令M = sqrt(N)+1)

因此要找的前N小的数,肯定在矩阵的前M行和前M列中。

所以,要找的是一个M*N的矩阵和另外一个(N-M)*M的矩阵。

这样的规模相当于M*(2N-M),相当于M*N

这样问题可以转化为:在M个长度为N的有序数组中,查找前N小的数。(*)

除了之前提到的方法,此题还有一个比较容易想到的方法:二分上界并计数。在INT_MIN~INT_MAX中二分第K大数的上界,每次对所有数组二分统计其中不大于上界的数的个数。总体的复杂度是O(R*M*lgN)。其中R是最坏情况下二分的次数。对于32位整数,最多二分32次,R=32。但是对于浮点数,需要的二分次数会增多。

在这个思路的基础上加以改进,把R改进为lgN,便可得到线性算法。

基本思路是,把二分时取数的范围从INT_MIN~INT_MAX缩小到这MN个数中。每次从这些数中选一个,来作为计数的上界。

选的方法:

每一轮计数时,先找出这M个数组的中位数,作为每个数组潜在的切分点,然后选择这些切分点的中位数作为上界。O(M)选出M个切分点,O(MlgM)把这些数排个序再选中间的,所以这一步可以O(MlgM)(注:无序数组选中位数有均摊O(M)的算法)。

但是为了每一轮都能缩小查找范围,所以对于每个数组,还要维护一个“潜在切分点的可能区间”,选择该数组的新切分点时,取这个区间的中位数。实际上就是对每个数组,维护一个二分切分点的过程信息。

这样一轮统计过后,某些偏大(或偏小)的切分点所在区间长度就需要减半。并且,至少有半数区间的长度是要减半的。(对于一个数组,不大于中位数的数的个数至少是一半。“不小于”同理)

由于所有数组一共有MN个数,因此在lg(MN)轮后,所有区间长度都会减到1。

整理一下复杂度。一共要进行lg(MN)次计数;每次计数需要O(MlgM)找切分点的中位数,以及O(MlgN)对一个数组计数。因此整体的复杂度是:
O(lg(MN)*(MlgM+MlgN)) = O(sqrt(N)*(lgN)^2) = o(sqrt(N)*sqrt(N)) = o(N)
ps. 所以这个算法复杂度其实是低于O(N)的.

(*)附代码:用以上算法实现在M个有序数组中,查找第K小的数。

转自: 

http://wolf5x.cc/blog/algorithm/young-tableau-smallest-kth#comment-123

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

// i, partition_point_lower_index_i, partition_point_upper_index_i
typedef pair<int, pair<int,int> > PartRange;
typedef vector<vector<int> > VVI;

class PartComparator {
  const VVI &ary;
public:
  PartComparator(const VVI &a): ary(a){}
  bool operator()(const PartRange &x, const PartRange &y) const
  {
    return ary[x.first][(x.second.first+x.second.second)/2]
    < ary[y.first][(y.second.first+y.second.second)/2];
  }
};

// Get the count of numbers less than or equal to upper
int getCount(VVI &num, int upper) {
  int ret = 0;
  for (int i = 0; i < num.size(); i++) {
    ret += upper_bound(num[i].begin(), num[i].end(), upper) - num[i].begin();
  }
  return ret;
}

int chooseKthSmallest(VVI num, int k) {
  int n = num.size();
  vector<PartRange> part(n);
  for (int i = 0; i < n; i++) {
    part[i] = make_pair(i, make_pair(0, num[i].size()-1));
  }
  int ans = 1<<30; // INT_MAX;
  while(part.size() > 0) {
    // sort all the medians
    sort(part.begin(), part.end(), PartComparator(num));
    // choose the median of medians
    int mid = part.size()/2;
    int upper = num[part[mid].first][(part[mid].second.first+part[mid].second.second)/2];
    int count = getCount(num, upper);
    if (count >= k) {
      // update answer
      ans = min(ans, upper);
      // halve the median intervals of which the median is too large
      for(int i = 0; i < part.size(); i++) {
        int mid = (part[i].second.first+part[i].second.second)/2;
        if (num[part[i].first][mid] >= upper) {
          part[i].second.second = mid-1;
        }
      }
    } else {
      // halve the median intervals of which the median is too small
      for (int i = 0; i < part.size(); i++) {
        int mid = (part[i].second.first+part[i].second.second)/2;
        if (num[part[i].first][mid] <= upper) {
          part[i].second.first = mid+1;
        }
      }
    }
    // remove the empty median intervals
    for (int i = part.size()-1; i >= 0; i--) {
      if (part[i].second.first > part[i].second.second) {
        swap(part[i], part[part.size()-1]);
        part.erase(part.end()-1);
      }
    }
  }
  return ans;
}
int main() {
  int v[][3] = {{1,2,3},{2,3,4},{3,4,5}};
  vector<int> vec0(v[0],v[0]+3); 
  vector<int> vec1(v[1],v[1]+3); 
  vector<int> vec2(v[2],v[2]+3); 
  int arr[] = {1,2,2,3,4,4,4,4,5,6,7,8,9,9,10};
  vector<int> vec3(arr, arr+sizeof(arr)/sizeof(int));

  VVI num;
  num.push_back(vec0);
  num.push_back(vec1);
  num.push_back(vec2);
  int up = distance(vec3.begin(), upper_bound(vec3.begin(), vec3.end(), 11));
  int low = distance(vec3.begin(), lower_bound(vec3.begin(), vec3.end(), 11));
  
  int res = chooseKthSmallest(num, 3);

  return 0;
}

--------------------------------------------------------------------------------------------------------

很久之后发现一种更好的解法,仍然二分,假设数据二分的范围是整个整数,那么log(2^32)次最多是32次,实际中范围可以由左上角和右下角来确定,anyway,二分的次数可以看做是一个常数。。。

剩下的问题就是给定一个target,求这个matrix中<=target的元素个数,这个是可以O(n)实现的,也就是从右上角开始往下找。。所以整体复杂度是O(n)。LeetCode上恰好有这么一道题:

Given a n x n matrix where each of the rows and columns are sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example:

matrix = [
   [ 1,  5,  9],
   [10, 11, 13],
   [12, 13, 15]
],
k = 8,

return 13.

Note:
You may assume k is always valid, 1 ≤ k ≤ n2.

-----------------------------------------------------------

class Solution:
    def count_lower(self, nums, target, n):
        j, res = n - 1, 0
        for i in range(n):
            while (j >= 0 and nums[i][j] > target):
                j -= 1
            res += (j + 1)
        return res

    def kthSmallest(self, matrix, k: int) -> int:
        n = len(matrix)
        left, right = matrix[0][0], matrix[n - 1][n - 1]
        while (left <= right):
            target = ((right - left) >> 1) + left
            lower = self.count_lower(matrix, target, n)
            if (lower < k):
                left = target + 1
            else:
                right = target - 1
        return left

------------------------------------------------

另外一道题,要求输出也排序,就只能用堆来玩了:

You are given two integer arrays nums1 and nums2sorted in ascending order and an integer k.

Define a pair (u, v) which consists of one element from the first array and one element from the second array.

Return the k pairs (u1, v1), (u2, v2), ..., (uk, vk) with the smallest sums.

Example 1:

Input: nums1 = [1,7,11], nums2 = [2,4,6], k = 3
Output: [[1,2],[1,4],[1,6]]
Explanation: The first 3 pairs are returned from the sequence: [1,2],[1,4],[1,6],[7,2],[7,4],[11,2],[7,6],[11,4],[11,6]

Example 2:

Input: nums1 = [1,1,2], nums2 = [1,2,3], k = 2
Output: [[1,1],[1,1]]
Explanation: The first 2 pairs are returned from the sequence: [1,1],[1,1],[1,2],[2,1],[1,2],[2,2],[1,3],[1,3],[2,3]

Example 3:

Input: nums1 = [1,2], nums2 = [3], k = 3
Output: [[1,3],[2,3]]
Explanation: All possible pairs are returned from the sequence: [1,3],[2,3]

Constraints:

  • 1 <= nums1.length, nums2.length <= 105
  • -109 <= nums1[i], nums2[i] <= 109
  • nums1 and nums2 both are sorted in ascending order.
  • 1 <= k <= 1000




from typing import List


import heapq
class Solution:
    def kSmallestPairs(self, nums1: List[int], nums2: List[int], k: int) -> List[List[int]]:
        l1, l2 = len(nums1), len(nums2)
        hq,res = [(nums1[0]+nums2[0],0,0)],[]
        while (hq and len(res) < k):
            csum,i,j = heapq.heappop(hq)

            res.append([nums1[i],nums2[j]])
            if (j+1<l2):
                heapq.heappush(hq,(nums1[i]+nums2[j+1],i,j+1))
            #列维度扩张由j说了算,只有列是0的时候考虑一下行维度的扩张,也控制了优先级
            if j == 0 and i+1<l1:
                heapq.heappush(hq,(nums1[i+1]+nums2[j],i+1,j))
        return res

s = Solution()
print(s.kSmallestPairs(nums1 = [1,2], nums2 = [3], k = 3))
print(s.kSmallestPairs(nums1 = [1,1,2], nums2 = [1,2,3], k = 2))
print(s.kSmallestPairs(nums1 = [1,7,11], nums2 = [2,4,6], k = 3))
#expected: [[0, -3], [0, -3], [0, -3], [0, -3], [0, -3], [0, 22], [0, 22], [0, 22], [0, 22], [0, 22], [0, 35], [0, 35], [0, 35], [0, 35], [0, 35], [0, 56], [0, 56], [0, 56], [0, 56], [0, 56], [0, 76], [0, 76]]
#          [[0, -3], [0, 22], [0, 35], [0, 56], [0, 76], [0, -3], [0, 22], [0, 35], [0, 56], [0, 76], [0, -3], [0, 22], [0, 35], [0, 56], [0, 76], [0, -3], [0, 22], [0, 35], [0, 56], [0, 76], [0, -3], [0, 22]]
#          [[0, -3], [0, -3], [0, -3], [0, -3], [0, -3], [0, 22], [0, 22], [0, 22], [0, 22], [0, 22], [0, 35], [0, 35], [0, 35], [0, 35], [0, 35], [0, 56], [0, 56], [0, 56], [0, 56], [0, 56], [0, 76], [0, 76]]
print(s.kSmallestPairs([0,0,0,0,0],[-3,22,35,56,76],22))

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值