在未排序的数组中找到第 k 个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。
示例 1:
输入: [3,2,1,5,6,4] 和 k = 2
输出: 5
示例 2:
输入: [3,2,3,1,2,4,5,5,6] 和 k = 4
输出: 4
说明:
你可以假设 k 总是有效的,且 1 ≤ k ≤ 数组的长度。
思路与代码
快速排序
时间复杂度 O(nlogn)
小顶堆
时间复杂度 O(nlogk)
class Solution {
public int findKthLargest(int[] nums, int k) {
MinHeap minHeap = new MinHeap(k);
for (int num : nums) {
if (minHeap.size() == k && minHeap.peek() < num) minHeap.poll();
minHeap.insert(num);
}
return minHeap.peek();
}
class MinHeap{
private int capacity;
private int count;
private int[] elements;
public MinHeap(int capacity) {
this.capacity = capacity;
elements = new int[capacity];
}
public void insert(int num) {
if (count >= capacity) return;
elements[count] = num; // 新加入的元素放堆底部
// 从底部向上调整
int pos = count;
while (pos > 0 && elements[pos] < elements[(pos - 1) / 2]) {
swap(pos, (pos - 1) / 2);
pos = (pos - 1) / 2;
}
count++;
}
public void poll() {
if (count == 0) return;
//将堆底元素放入堆顶,相当于移除了堆顶元素
elements[0] = elements[--count];
// 从上向下进行调整
int pos = 0;
while (true) {
// 将当前节点值与左右节点的较小值进行交换
int targetPos = pos;
if (2 * pos + 1 < count && elements[targetPos] > elements[2 * pos + 1]) targetPos = 2 * pos + 1;
if (2 * pos + 2 < count && elements[targetPos] > elements[2 * pos + 2] ) targetPos = 2 * pos + 2;
if (targetPos == pos) break; // 说明已经调整完了
swap(pos, targetPos);
pos = targetPos;
}
}
public int peek() {
return elements[0];
}
public int size() {
return this.count;
}
public void swap(int a, int b) {
int tmp = elements[a];
elements[a] = elements[b];
elements[b] = tmp;
}
}
}
分治法 快速选择
时间复杂度 O(n)
时间复杂度 O(logn)
递归栈所占空间
class Solution {
Random random = new Random();
public int findKthLargest(int[] nums, int k) {
int index = nums.length - k;
return quickSelect(nums, 0, nums.length - 1, index);
}
public int quickSelect(int[] nums, int l, int r, int index) {
int p = partition(nums, l, r);
if (p == index) return nums[p];
return p > index ? quickSelect(nums, l, p - 1, index) : quickSelect(nums, p + 1, r, index);
}
public int partition(int[] nums, int l, int r) {
int tmp = random.nextInt(r - l + 1) + l;
int num = nums[tmp];
swap(nums, r, tmp);
int index = l;
for (int i = l; i <= r; i++) {
if (nums[i] <= num) {
swap(nums, i, index++);
}
}
return index - 1;
}
public void swap(int[] nums, int a, int b) {
int tmp = nums[a];
nums[a] = nums[b];
nums[b] = tmp;
}
}