以数组存放,下标从0开始,大顶堆为例
1. 堆化(往堆中插入一个元素)
从下往上
/**
* @brief 往堆中插入新元素
* @param a 堆
* @param data 新元素
*/
void Insert(vector<int> &a, int data) {
a.push_back(data);
int i = a.size() - 1;
while (1) {
int parent = (i-1)/2;
if (parent >= 0 && a[i] > a[parent]) {
swap(a[i], a[parent]);
i = parent;
}
else {
break;
}
}
}
从上往下
/**
* @brief 自上往下堆化
* @param a 数据数组,原地算法,所以不能通过a.size()来判断当前的堆容量
* @param n 堆容量
* @param i 当前处理的下标位置
*/
void Heapify(vector<int> &a, int n, int i) {
while (1) {
int max_pos = i;
int left = 2 * i + 1;
int right = 2 * i + 2;
if (left < n && a[max_pos] < a[left]) max_pos = left;
if (right < n && a[max_pos] < a[right]) max_pos = right;
if (max_pos == i) break;
swap(a[i], a[max_pos]);
i = max_pos;
}
}
关键在于正确写成下标以及边界条件,堆化的时间复杂度和树的高度成正比,由于堆是完全二叉树,为O(logn)
2. 删除堆顶元素
/**
* @brief 从已经建好的堆中删除堆顶元素
* @param a 堆
*/
void RemoveHeapTop(vector<int> &a) {
if (a.size() == 0) return;
a[0] = a.back();
a.pop_back();
Heapify(a, a.size(), 0);
}
删除就是用最后一个元素覆盖堆顶元素,在进行堆化,注意Heapify是从上往下进行堆化的,时间复杂度和树的高度成正比,为O(logn)
3. 建堆
/**
* @brief 原地建堆
* @param a 原始数据数组
*/
void BuildHeap(vector<int> &a) {
for (int i = a.size()/2 - 1; i>=0; --i) {
Heapify(a, a.size(), i);
}
}
分析比较复杂,时间复杂度是O(n)
4. 堆排序
/**
* @brief 自上往下堆化
* @param a 数据数组,原地算法,所以不能通过a.size()来判断当前的堆容量
* @param n 堆容量
* @param i 当前处理的下标位置
*/
void Heapify(vector<int> &a, int n, int i) {
while (1) {
int max_pos = i;
int left = 2 * i + 1;
int right = 2 * i + 2;
if (left < n && a[max_pos] < a[left]) max_pos = left;
if (right < n && a[max_pos] < a[right]) max_pos = right;
if (max_pos == i) break;
swap(a[i], a[max_pos]);
i = max_pos;
}
}
/**
* @brief 原地建堆
* @param a 原始数据数组
*/
void BuildHeap(vector<int> &a) {
for (int i = a.size()/2 - 1; i>=0; --i) {
Heapify(a, a.size(), i);
}
}
/**
* @brief 实现堆排序
* @param a 原始数组
*/
void HeapSort(vector<int> &a) {
BuildHeap(a);
int k = a.size() - 1;
while (k > 0) {
swap(a[0],a[k]);
Heapify(a, k,0); //前k个元素堆化
--k;
}
}
时间复杂度O(nlogn),最好最坏平均都是这个,但是由于堆化时需要较好第一个元素和最后一个元素,因此堆排序不是稳定的排序算法
以上完整代码总结如下:
#include <iostream>
#include <vector>
using namespace std;
/**
* @brief 往堆中插入新元素
* @param a 堆
* @param data 新元素
*/
void Insert(vector<int> &a, int data) {
a.push_back(data);
int i = a.size() - 1;
while (1) {
int parent = (i-1)/2;
if (parent >= 0 && a[i] > a[parent]) {
swap(a[i], a[parent]);
i = parent;
}
else {
break;
}
}
}
/**
* @brief 自上往下堆化
* @param a 数据数组,原地算法,所以不能通过a.size()来判断当前的堆容量
* @param n 堆容量
* @param i 当前处理的下标位置
*/
void Heapify(vector<int> &a, int n, int i) {
while (1) {
int max_pos = i;
int left = 2 * i + 1;
int right = 2 * i + 2;
if (left < n && a[max_pos] < a[left]) max_pos = left;
if (right < n && a[max_pos] < a[right]) max_pos = right;
if (max_pos == i) break;
swap(a[i], a[max_pos]);
i = max_pos;
}
}
/**
* @brief 原地建堆
* @param a 原始数据数组
*/
void BuildHeap(vector<int> &a) {
for (int i = a.size()/2 - 1; i>=0; --i) {
Heapify(a, a.size(), i);
}
}
/**
* @brief 实现堆排序
* @param a 原始数组
*/
void HeapSort(vector<int> &a) {
BuildHeap(a);
int k = a.size() - 1;
while (k > 0) {
swap(a[0],a[k]);
Heapify(a, k,0); //前k个元素堆化
--k;
}
}
/**
* @brief 对已经建好的堆进行排序
* @param a 堆
*/
void HeapOnlySort(vector<int> &a) {
int k = a.size() - 1;
while (k > 0) {
swap(a[0],a[k]);
Heapify(a, k,0); //前k个元素堆化
--k;
}
}
/**
* @brief 从已经建好的堆中删除堆顶元素
* @param a 堆
*/
void RemoveHeapTop(vector<int> &a) {
if (a.size() == 0) return;
a[0] = a.back();
a.pop_back();
Heapify(a, a.size(), 0);
}
int main() {
vector<int> a{1,5,2,3,4,7,9,8,6};
//HeapSort(a);
BuildHeap(a);
Insert(a, 7);
Insert(a, 10);
Insert(a, -5);
RemoveHeapTop(a);
RemoveHeapTop(a);
HeapOnlySort(a);
for (auto &m : a) {
cout << m << " ";
}
return 0;
}
5. 合并多个有序链表
思路:维护一个小顶堆,首先将所有非空链表压入堆,然后每次弹出堆顶元素,若弹出元素存在下一个非空结点,则将下一个结点压入堆,直到堆为空,需要自己实现建堆,出堆,往堆中插入元素,编码较为复杂
/**
* Definition for singly-linked list.
* struct ListNode {
* int val;
* ListNode *next;
* ListNode() : val(0), next(nullptr) {}
* ListNode(int x) : val(x), next(nullptr) {}
* ListNode(int x, ListNode *next) : val(x), next(next) {}
* };
*/
class Solution {
private:
void heapify(vector<ListNode*> &heap, int n, int i) {
if (i < 0 || i>= n) return;
int min_pos, left, right;
while (1) {
min_pos = i;
left = 2*i + 1;
right = 2*i + 2;
if (left<n && heap[left]->val < heap[min_pos]->val) min_pos = left;
if (right<n && heap[right]->val < heap[min_pos]->val) min_pos = right;
if (min_pos == i) break;
swap(heap[min_pos], heap[i]);
i = min_pos;
}
}
void buildHeap(vector<ListNode*> &heap) {
int start = heap.size()/2 - 1;
for (int i=start; i>=0; --i) {
heapify(heap, heap.size(), i);
}
}
void removeTop(vector<ListNode*> &heap) {
if (0 == heap.size()) return;
heap[0] = heap.back();
heap.pop_back();
if (heap.size() < 2) return;
heapify(heap, heap.size(), 0);
}
void insert(vector<ListNode*> &heap, ListNode* node) {
heap.push_back(node);
if (heap.size() < 2) return;
int i = heap.size() - 1;
int parent;
while (1) {
parent = (i-1) / 2;
if (parent<0 || heap[parent]->val < heap[i]->val) break;
swap(heap[parent], heap[i]);
i = parent;
}
}
public:
ListNode* mergeKLists(vector<ListNode*>& lists) {
// 0.准备
ListNode* guard = new ListNode();
ListNode* curr = guard;
// 1.建堆,小顶堆
vector<ListNode*> heap;
for (auto &m : lists) {
if (m) {
heap.push_back(m);
}
}
buildHeap(heap);
// 2.循环处理
while (1) {
// 2.0 若堆为空,则退出
if (0 == heap.size()) break;
// 2.1 存放结果
curr->next = heap[0];
curr = curr->next;
// 2.2 若堆顶元素下一个结点非空,则替换堆顶元素为它下一个结点,再堆化
if (heap[0]->next) {
heap[0] = heap[0]->next;
heapify(heap, heap.size(), 0);
}
// 2.3 若堆顶元素下一个结点为空,则删除堆顶元素
else {
removeTop(heap);
}
}
// 3. 返回结果
return guard->next;
}
};
时间复杂度O(nlogk),其中n为所有结点总数,k为链表数量,空间复杂度O(k)
这道题可以使用c++自带的优先队列,priority_queue,底层实现就是堆,需要自行定义<操作符函数,默认是大顶堆
/**
* Definition for singly-linked list.
* struct ListNode {
* int val;
* ListNode *next;
* ListNode() : val(0), next(nullptr) {}
* ListNode(int x) : val(x), next(nullptr) {}
* ListNode(int x, ListNode *next) : val(x), next(next) {}
* };
*/
struct Status {
int val;
ListNode* ptr;
bool operator < (const Status &rhs) const {
return val > rhs.val;
}
Status(int v, ListNode* p) : val(v), ptr(p) {} //可以省略,使用默认初始化列表,但是写了更加清晰
};
class Solution {
public:
ListNode* mergeKLists(vector<ListNode*>& lists) {
// 1.初始化优先队列
priority_queue<Status> q;
for (auto &m : lists) {
if (m) {
q.push(Status(m->val, m));
}
}
// 2.循环处理
ListNode *guard = new ListNode();
ListNode *tail = guard;
while (!q.empty()) {
// 2.1 取出堆顶元素
tail->next = q.top().ptr;
tail = tail->next;
q.pop();
// 2.2 若堆顶元素存在下一个非空结点则压入堆,注意此时tail就是上一个堆顶元素对应的指针
if (tail->next) {
q.push(Status{tail->next->val, tail->next});
}
}
// 3.返回结果
return guard->next;
}
};
思路和上面基本一致,不过用c++自带的priority_queue来代替自己实现的堆,编码更加简单,不容易出错,不过注意这种方式需要定义堆的元素类型,并且需要定义<运算符函数,最好也定义构造函数
6. 利用堆求 Top K
思路:维护一个大顶堆,每次将小于堆顶的元素压入堆,然后弹出新的堆的堆顶,到最后堆中剩下的k个元素就是最小的k个元素
class Solution {
public:
vector<int> smallestK(vector<int>& arr, int k) {
vector<int> ans;
if (k == 0) return ans;
// 1.初始化优先队列,默认是大顶堆
priority_queue<int, vector<int>, less<>> q;
//priority_queue<int> q;
for (int i=0; i<k; ++i) {
q.push(arr[i]);
}
// 2.每次比较,若当前元素小于堆顶元素则压入堆,然后弹出堆顶元素
for (int i=k; i<arr.size(); ++i) {
if (arr[i] < q.top()) {
q.push(arr[i]);
q.pop();
}
}
// 3.最终堆中元素就是答案
while (!q.empty()) {
ans.push_back(q.top());
q.pop();
}
return ans;
}
};
时间复杂度O(nlogk),空间复杂度O(k)
7. 利用堆求中位数
思路:维护两个堆,一个大顶堆,一个小顶堆,当元素小于大顶堆堆顶时压入大顶堆,否则压入小顶堆,然后平衡两个堆使得小顶堆数量为sum/2,其中sum是总的元素个数;再求中位数时,如果元素总数为奇数,则返回大顶堆堆顶,若元素总数为偶数,则返回两个堆堆顶元素的平均值
class MedianFinder {
private:
priority_queue<int, vector<int>, less<>> max_top_heap; //大顶堆
priority_queue<int, vector<int>, greater<>> min_top_heap; //小顶堆
public:
/** initialize your data structure here. */
MedianFinder() {
}
void addNum(int num) {
// 1.入堆
if (0 == max_top_heap.size() || num < max_top_heap.top()) {
max_top_heap.push(num);
}
else {
min_top_heap.push(num);
}
// 2.平衡两个堆
int sum = min_top_heap.size() + max_top_heap.size();
int temp;
while (1) {
if (min_top_heap.size() == sum/2) break; //平衡了
else if (min_top_heap.size() < sum/2) {
temp = max_top_heap.top();
max_top_heap.pop();
min_top_heap.push(temp);
}
else {
temp = min_top_heap.top();
min_top_heap.pop();
max_top_heap.push(temp);
}
}
}
double findMedian() {
int sum = min_top_heap.size() + max_top_heap.size();
if (sum % 2 == 0) {
return ((double)max_top_heap.top() + min_top_heap.top()) / 2;
}
else {
return max_top_heap.top();
}
}
};
/**
* Your MedianFinder object will be instantiated and called as such:
* MedianFinder* obj = new MedianFinder();
* obj->addNum(num);
* double param_2 = obj->findMedian();
*/
时间复杂度:加入元素是O(log(n/2)),最大是log(n),空间复杂度为O(n),求中位数的时间复杂度为O(1)