第k大元素问题
在LintCode上刷了一道题,题目如下:
在数组中找到第 k 大的元素。
算法描述:
- 如果数组元素个数 ≤ 75 \le75 ≤75,对数组使用冒泡排序算法,直接找到第k个元素。否则按以下算法进行查找。
- 把数组按5个一组进行分组,共 ⌈ n 5 ⌉ \lceil \frac{n}{5} \rceil ⌈5n⌉组,对每个小组使用冒泡排序算法找到其中位数,把每个小组的中位数交换到数组的前面,形成一个 ⌈ n 5 ⌉ \lceil \frac{n}{5} \rceil ⌈5n⌉个元素的子集 S S S。
- 递归地在 S S S子集中找出其中位数 m m m。
- 使用划分算法,把大于 m m m的元素放到数组的左边,小于 m m m的元素放到数组的右边。划分完成后,得到元素 m m m在数组中的位置 m i m_i mi。
- 如果 k = m i k=m_i k=mi,则 m m m为问题的解。如果 k < m i k<m_i k<mi,递归地在 m i m_i mi的左边进行查找,如果 k > m i k>m_i k>mi,递归地在 m i m_i mi的右边进行查找。
算法时间复杂度分析
设总体的平均时间复杂度为 T ( n ) T(n) T(n)
- 对小组排序取中位数的平均时间复杂度为 f 1 ( n ) = 5 × 3 × n 5 = 3 n f_1(n) = 5 \times 3 \times \frac{n}{5}=3n f1(n)=5×3×5n=3n
- 对 S S S子集取中位数的平均时间复杂度为 f 2 ( n ) = T ( n 5 ) f_2(n) = T(\frac{n}{5}) f2(n)=T(5n)
3.划分算法的平均时间复杂度为 f 3 ( n ) = n f_3(n)=n f3(n)=n
4.在左边或右边查找的平均时间复杂度为 f 4 ( n ) = T ( n 2 ) f_4(n)=T(\frac{n}{2}) f4(n)=T(2n)
5.则
T ( n ) = 3 n + T ( n 5 ) + n + T ( n 2 ) T ( n ) ≈ 8 n = O ( n ) T(n) = 3n + T(\frac{n}{5}) + n + T(\frac{n}{2}) \\ T(n) \approx 8n =O(n) T(n)=3n+T(5n)+n+T(2n)T(n)≈8n=O(n)
代码如下:
class Solution {
public:
/**
* @param n: An integer
* @param nums: An array
* @return: the Kth largest element
*/
int kthLargestElement(int n, vector<int> &nums) {
int k = pick(nums, 0, nums.size() - 1, n - 1);
return nums[k];
}
int pick(vector<int> &a, int l, int h, int k) {
int kk = l + k;
if (h - l <= 74) {
if (k <= ((h - l) >> 1)) bobSort(a, l, h, k);
else reversBobSort(a, l, h, k);
return kk;
}
int s = group(a, l, h);
int m = pick(a, l, s - 1, (s - l) >> 1);
m = partition(a, l, h, m);
if (m == kk) return m;
if (m > kk) return pick(a, l, m - 1, k);
return pick(a, m + 1, h, kk - m - 1);
}
int group(vector<int> &a, int l, int h) {
int s = l, j, m;
for (int i = l; i <= h; i += 5) {
j = i + 4;
if (j > h) j = h;
bobSort(a, i, j, 2);
m = (i + j) >> 1;
swap(a, s, m);
s++;
}
return s;
}
int partition(vector<int> &a, int l, int h, int m) {
swap(a, l, m);
int t = a[l];
while (l < h) {
while (h > l && a[h] <= t) h--;
if (h > l) a[l++] = a[h];
while (l < h && a[l] >= t) l++;
if (l < h) a[h--] = a[l];
}
a[l] = t;
return l;
}
inline void swap(vector<int> &a, int i, int j) {
int t = a[i];
a[i] = a[j];
a[j] = t;
}
void bobSort(vector<int> &a, int l, int h, int k) {
int kk = l + k + 1;
for (int i = l + 1; i <= kk; i++) {
bool swp = false;
for (int j = h; j >= i; j--) {
if (a[j] > a[j - 1]) {
swap(a, j, j - 1);
swp = true;
}
}
if (!swp) break;
}
}
void reversBobSort(vector<int> &a, int l, int h, int k) {
int kk = l + k - 1;
for (int i = h - 1; i >= kk; i--) {
bool swp = false;
for (int j = l; j <= i; j++) {
if (a[j] < a[j + 1]) {
swap(a, j, j + 1);
swp = true;
}
}
if (!swp) break;
}
}
};
以下是提交结果
之前用java代码提交过几次,执行时间都在480ms左右,c++代码的执行时间大约是java的
1
10
\frac{1}{10}
101。
看到有很多人c++版本的执行时间为50ms,算法应该可能进一步优化,例如把递归实现改成无栈非递归实现。