这里我们选择第k个最大元素来做实例。第一反应是这个题目似乎很简单,直接遍历,每一次挑最大的元素放入数组,最后就可以得到我们想要的结果,但是这个时间复杂度太高, O(K * n)。当 K 是比较小的常量时,时间复杂度是 O(n);但当 K 等于 n/2 或者 n 时时间复杂度就是 O(n^2) 了。我们需要用到在快排中的划分区间的方法。
我们选择数组区间 A[0...n-1]的最后一个元素 A[n-1]作为 pivot,对数组 A[0...n-1]原地分区,这样数组就分成了三部分,A[0...p-1]、A[p]、A[p+1...n-1];
int partition(vector<int>& nums,int p, int r)
{
int pivot = nums[r];
int i = p;
for(int j = p;j <= r-1;j++)
{
if(nums[j] < pivot)
{
swap(nums[i], nums[j]);
i++;
}
}
swap(nums[i], nums[r]);
return i;
}
这个就是我们的核心代码,然后如果 q=n-k,这里的q是数组的下标,第K个最大元素就应该是n-k,那 A[q]就是要求解的元素;如果 n-k>q, 说明第 k 大元素出现在 A[q+1...n-1]区间,我们再按照上面的思路递归地在 A[q+1...n-1]这个区间内查找。代码如下:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
int ans = -1;
int p = 0, r = n-1;
while(true)
{
int q = partition(nums, p, r);
if(q == n - k)
{
ans = nums[q];
break;
}
else if(q < n-k)
{
p = q+1;
}
else{
r = q-1;
}
}
return ans;
}
这里的时间复杂度就是O(n) 。
第一次分区查找,我们需要对大小为 n 的数组执行分区操作,需要遍历 n 个元素。第二次分区查找,我们只需要对大小为 n/2 的数组执行分区操作,需要遍历 n/2 个元素。依次类推n/2、n/4、n/8、n/16.……直到区间为 1。n+n/2+n/4+n/8+...+1。最后的和等于 2n-1。时间复杂度就为 O(n)。
大部分情况下时间复杂度都为O(n),如果全员顺序,那就是O(n);这里可以随机化分区函数中的pivot,方法也简单直接在开头选择pivot时,让r与(p + r)/ 2 交换就可以随机了。
如下:
int partition(vector<int>& nums,int p, int r)
{
if(r > p)
{
int ran = (p + r) / 2;
swap(nums[ran], nums[r]);
}
int pivot = nums[r];
int i = p;
for(int j = p;j <= r-1;j++)
{
if(nums[j] < pivot)
{
swap(nums[i], nums[j]);
i++;
}
}
swap(nums[i], nums[r]);
return i;
}
将前面的思路拼接起来就是下面的代码:
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
int ans = -1;
int p = 0, r = n-1;
while(true)
{
int q = partition(nums, p, r);
if(q == n - k)
{
ans = nums[q];
break;
}
else if(q < n-k)
{
p = q+1;
}
else{
r = q-1;
}
}
return ans;
}
int partition(vector<int>& nums,int p, int r)
{
if(r > p)
{
int ran = (p + r) / 2;
swap(nums[ran], nums[r]);
}
int pivot = nums[r];
int i = p;
for(int j = p;j <= r-1;j++)
{
if(nums[j] < pivot)
{
swap(nums[i], nums[j]);
i++;
}
}
swap(nums[i], nums[r]);
return i;
}
求第K个最小元素就是将q = k-1,其他不变;
int findKthLargest(vector<int>& nums, int k) {
int n = nums.size();
int ans = -1;
int p = 0, r = n-1;
while(true)
{
int q = partition(nums, p, r);
if(q == k - 1)
{
ans = nums[q];
break;
}
else if(q < k - 1)
{
p = q+1;
}
else{
r = q-1;
}
}
return ans;
}
int partition(vector<int>& nums,int p, int r)
{
if(r > p)
{
int ran = (p + r) / 2;
swap(nums[ran], nums[r]);
}
int pivot = nums[r];
int i = p;
for(int j = p;j <= r-1;j++)
{
if(nums[j] < pivot)
{
swap(nums[i], nums[j]);
i++;
}
}
swap(nums[i], nums[r]);
return i;
}
这样这个题目就完成了。