优先队列使用堆来实现,所以我们先学习堆。
堆的定义
堆是一种特殊的树。堆总是满足以下定义:
- 堆中的某个结点的值总是大于或小于其父结点的值(小根堆或大顶堆)。
- 堆总是一棵完全二叉树
完全二叉树
完全二叉树:除最底层外,其他层必须满结点,且最底层的结点必须从左到右紧凑排列的二叉树。
假如给一棵完全二叉树的结点按从上到下、从左到右顺序编号(序号从1开始),则:
- 编号为i的结点的父节点编号为 int(i/2)
- 编号为i的结点的左孩子编号为2i,右孩子编号为2i+1
优先队列的实现(C++)
优先队列使用堆来实现,而堆是一棵完全二叉树,因为完全二叉树结点较为紧凑,通常使用数组来储存。
由于堆中结点的编号习惯性从1开始,所以第一个结点从数组的下标为1的位置开始储存,下标0的位置闲置不用。
优先队列内部需要维护三个变量,用于储存堆的数组、已用空间size、数组容量capacity:
class PriorityQueue
{
int *data;
int size;
int capacity;
};
下面以大根堆为例,实现优先队列的入队、出队操作。
入队操作的实现
入队操作的实现思路如下:
首先将需要入队的元素放到数组的末尾,即堆中最底层、最右边的位置,然后根据堆的定义,需要将新入队的元素执行ShiftUp操作。
代码:
void add(int data)
{
//当数组无剩余空间时终止
assert(size < capacity);
//由于下标0空置,size指向最后一个元素,需先自增再赋值
this->data[++size] = data;
//对新入队的元素执行shiftup
shift_up(size);
}
ShiftUp
ShiftUp是指为保持堆的定义,将位于底部的元素逐渐上升的过程。此处涉及到上文的完全二叉树性质一,即编号为i的结点的父节点编号为int(i/2),这里的int(i/2)是指整形除法。ShiftUp实现代码:
//i为需要执行ShiftUp操作的元素下标
void shift_up(int i)
{
//当i不是根结点 且 i的父结点小于i时
while (i > 1 && data[i / 2] < data[i])
{
//交换结点i与其父结点
swap(data[i], data[i / 2]);
//由于上一句交换了结点,这里更新i
i /= 2;
}
}
值得一提的是,这种连续的swap操作可以优化成赋值操作,思想类似于插入排序的优化,优化后的代码如下;
void shift_up_opt(int i)
{
//保存初始的data[i]的值
int tmp = data[i];
//储存最后需要将tmp赋值过去的下标
int target = 0;
while (i > 1 && data[i / 2] < tmp)
{
data[i] = data[i / 2];
//更新target
target = i / 2;
i /= 2;
}
//赋值
data[target] = tmp;
}
出队操作的实现
思路:每次出队根节点,然后将最末尾的结点移到根节点的位置,再对移动过来的根节点(原最末尾结点)执行ShiftDown操作。
代码:
int extractMax()
{
//当数组为空时终止
assert(size > 0);
//暂存返回值
int ret = data[1];
//将最末尾结点移到根结点位置并size-1
data[1] = data[size--];
//对根结点执行shiftdown
shift_down(1);
return ret;
}
ShiftDown
与ShiftUp操作对应,ShiftDown指为维持堆的定义,将根节点元素逐渐下移的过程。
代码:
void shift_down(int i)
{
//swapNode为需要交换的结点下标
int swapNode = 0;
//当结点i有左孩子结点时循环
while (2 * i <= size)
{
//swapNode指向左孩子
swapNode = 2 * i;
//当结点i有右孩子 且 右孩子比左孩子大时,swapNode指向右孩子
if (swapNode + 1 <= size && data[swapNode + 1] > data[swapNode])
++swapNode;
//如结点i已经比最大的孩子结点大,则跳出循环
if (data[i] > data[swapNode])
break;
//否则交换swapNode结点与结点i
swap(data[i], data[swapNode]);
//由于上行代码已经交换结点,更新i
i = swapNode;
}
}
出队入队操作的时间复杂度
皆为O(log(n))。
完整实现代码
#include <iostream>
#include <cassert>
using namespace std;
class PriorityQueue
{
int *data;
int size;
int capacity;
//i -> index
void shift_up(int i)
{
while (i > 1 && data[i / 2] < data[i])
{
swap(data[i], data[i / 2]);
i /= 2;
}
}
void shift_up_opt(int i)
{
int tmp = data[i];
int target = 0;
while (i > 1 && data[i / 2] < tmp)
{
data[i] = data[i / 2];
target = i / 2;
i /= 2;
}
data[target] = tmp;
}
//i -> index
void shift_down(int i)
{
int swapNode = 0;
while (2 * i <= size)
{
swapNode = 2 * i;
if (swapNode + 1 <= size && data[swapNode + 1] > data[swapNode])
++swapNode;
if (data[i] > data[swapNode])
break;
swap(data[i], data[swapNode]);
i = swapNode;
}
}
public:
PriorityQueue(int capacity)
{
data = new int[capacity + 1];
size = 0;
this->capacity = capacity;
}
~PriorityQueue()
{
delete[] data;
}
void print()
{
for (int i = 1; i <= size; ++i)
cout << data[i] << " ";
cout << endl;
}
void add(int data)
{
//stop when array is full
assert(size < capacity);
this->data[++size] = data;
shift_up(size);
}
int extractMax()
{
//stop when array is not empty
assert(size > 0);
int ret = data[1];
data[1] = data[size--];
shift_down(1);
return ret;
}
};
int main()
{
PriorityQueue pq(10);
for (int i = 0; i < 7; ++i)
pq.add(i);
pq.print();
for (int i = 0; i < 7; ++i)
cout << pq.extractMax() << " ";
cout << endl;
return 0;
}