题目:《程序员面试金典(第5版)》P342
随机生成一些数字并传入某个方法。编写一个程序,每当接收到新数字的时候,找出并记录中位数。
提示:用一个最大堆和一个最小堆,最大堆记录小于中位数的数,最小堆记录大于等于中位数的数。最大堆和最小堆的元素个数相差不超过1。函数GetNumber() 接收到一个数字时,与中位数对比,插入到最大堆或最小堆中。函数FindMiddleNumber() 根据最大堆和最小堆,用O(1)的时间找出中位数。
vector<int> minheap;
vector<int> maxheap;
void Swap(int &a, int &b)
{
int temp = a;
a = b;
b = temp;
}
//数组maxheap第n-1位(含)前的数字已经是最大堆,现在往堆中加入数组第n位的数据,使得堆的大小加1
bool MaxHeapFixup(vector<int> &maxheap, int n)
{
if (n >= maxheap.size())
return false;
//若父节点小于子节点,则交换两者数值
for (int i = (n - 1) / 2; i >= 0 && maxheap[i]<maxheap[n]; n = i, i = (n - 1) / 2)
Swap(maxheap[n], maxheap[i]);
return true;
}
bool MinHeapFixup(vector<int> &minheap, int n)
{
if (n >= minheap.size())
return false;
//若父节点大于子节点,则交换两者数值
for (int i = (n - 1) / 2; i >= 0 && minheap[i]>minheap[n]; n = i, i = (n - 1) / 2)
Swap(minheap[n], minheap[i]);
return true;
}
bool MaxHeapFixdown(vector<int> &maxheap, int i)
{
if (2 * i + 1 >= maxheap.size())//没有子节点,无需下沉
return false;
int child = 2 * i + 1;
while (child<maxheap.size())
{
if (child + 1<maxheap.size() && maxheap[child + 1]>maxheap[child])//找出最大的孩子节点
child++;
if (maxheap[i]<maxheap[child])
Swap(maxheap[i], maxheap[child]);
else
break;
i = child;
child = 2 * i + 1;
}
return true;
}
bool MinHeapFixdown(vector<int> &minheap, int i)
{
if (2 * i + 1 >= minheap.size())//没有子节点,无需下沉
return false;
int child = 2 * i + 1;
while (child<minheap.size())
{
if (child + 1<minheap.size() && minheap[child + 1]<minheap[child])//找出最小的孩子节点
child++;
if (minheap[i]>minheap[child])
Swap(minheap[i], minheap[child]);
else
break;
i = child;
child = 2 * i + 1;
}
return true;
}
bool MaxHeapDeleteNumber(vector<int> &maxheap)
{
if (maxheap.empty())
return false;
Swap(maxheap[0], maxheap[maxheap.size() - 1]);//交换头尾两个数
maxheap.erase(maxheap.end() - 1);//删除最后一个数
MaxHeapFixdown(maxheap, 0);
return true;
}
bool MinHeapDeleteNumber(vector<int> &minheap)
{
if (minheap.empty())
return false;
Swap(minheap[0], minheap[minheap.size() - 1]);//交换头尾两个数
minheap.erase(minheap.end() - 1);//删除最后一个数
MinHeapFixdown(minheap, 0);
return true;
}
//接收数据
void GetNumber(int n)
{
if (maxheap.empty() && minheap.empty())
minheap.push_back(n);
else if (maxheap.size() == minheap.size())
{
int mid =( maxheap[0] + minheap[0])/2;
if (n >= mid)
{
minheap.push_back(n);
MinHeapFixup(minheap, minheap.size() - 1);
}
else
{
maxheap.push_back(n);
MaxHeapFixup(minheap, maxheap.size() - 1);
}
}
else if (maxheap.size() > minheap.size())
{
int mid = maxheap[0];
if (n >= mid)
{
minheap.push_back(n);
MinHeapFixup(minheap, minheap.size() - 1);
}
else
{
maxheap.push_back(n);
MaxHeapFixup(maxheap, maxheap.size() - 1);
int tmp = maxheap[0];
MaxHeapDeleteNumber(maxheap);
minheap.push_back(tmp);
MinHeapFixup(minheap, minheap.size() - 1);
}
}
else//maxheap.size() < minheap.size()
{
int mid = minheap[0];
if (n >= mid)
{
minheap.push_back(n);
MinHeapFixup(minheap, minheap.size() - 1);
int tmp = minheap[0];
MinHeapDeleteNumber(minheap);
maxheap.push_back(tmp);
MaxHeapFixup(maxheap, maxheap.size() - 1);
}
else
{
maxheap.push_back(n);
MaxHeapFixup(maxheap, maxheap.size() - 1);
}
}
}
//找出中位数
int FindMiddleNumber()
{
if (maxheap.empty() && minheap.empty())
return 0;
if (maxheap.size() == minheap.size())
return (maxheap[0] + minheap[0]) / 2;
if (maxheap.size() < minheap.size())
return minheap[0];
else
return maxheap[0];
}