题目翻译:
模拟栈的基本操作(入和弹出),同时新加一项操作——输出栈内目前的中位数。
题解思路:
在更新的同时输出中位数,可以考虑使用对顶堆或者树状数组+二分搜索。
对顶堆:
本题如果直接用priority_queue的话,由于priority_queue不能直接删除特定元素,不是很方便,所以这里用multiset来表示大小根堆。
以下是通过小根堆和大根堆来找到中位数的思路:
使用两个堆:
大根堆(max-heap): 存储较小的一半数据,堆顶是这一部分的最大值。
小根堆(min-heap): 存储较大的一半数据,堆顶是这一部分的最小值。
平衡堆:当插入一个新元素时,如果它比大根堆的堆顶小,插入大根堆;否则插入小根堆。
插入后,如果两个堆的大小相差超过1,则将多余元素从较大的堆中移动到较小的堆中,以保持两个堆的平衡。
查找中位数:如果两个堆的大小相同,中位数是两个堆顶的平均值。
如果堆的大小不同,中位数是元素更多的那个堆的堆顶。
树状数组+二分搜索:
树状数组用于维护栈中每个元素的频率分布。getsum(mid)
函数返回的是栈中所有值小于等于 mid
的元素的数量(频率前缀和)。通过计算 getsum(mid)
,我们可以知道在当前栈中,有多少元素的值小于或等于 mid
。再结合二分搜索找到left==right的位置就是mid。
代码:
对顶堆:
#include <bits/stdc++.h>
using namespace std;
stack<int> s;
multiset<int> maxHeap, minHeap;
void pushNum(int num)
{
if (maxHeap.empty() || num <= (*prev(maxHeap.end())))
maxHeap.insert(num);
else
minHeap.insert(num);
if (maxHeap.size() > minHeap.size() + 1)
{ // 调整
minHeap.insert(*prev(maxHeap.end()));
maxHeap.erase((prev(maxHeap.end())));
}
else if (minHeap.size() > maxHeap.size())
{
maxHeap.insert(*minHeap.begin());
minHeap.erase(minHeap.begin());
}
}
void popNum(int num)
{
if (num <= (*prev(maxHeap.end())))
maxHeap.erase(maxHeap.find(num));
else
minHeap.erase(minHeap.find(num));
if (maxHeap.size() > minHeap.size() + 1)
{ // 调整
minHeap.insert(*prev(maxHeap.end()));
maxHeap.erase(prev(maxHeap.end()));
}
else if (minHeap.size() > maxHeap.size())
{
maxHeap.insert(*minHeap.begin());
minHeap.erase(minHeap.begin());
}
}
int main()
{
int n, key;
scanf("%d", &n);
for (int i = 0; i < n; i++)
{
char str[15];
scanf("%s", &str);
if (str[1] == 'o')
{
if (s.empty())
printf("Invalid\n");
else
{
printf("%d\n", s.top());
popNum(s.top());
s.pop();
}
}
else if (str[1] == 'u')
{
scanf("%d", &key);
pushNum(key);
s.push(key);
}
else
{
if (s.empty())
printf("Invalid\n");
else
{
printf("%d\n", (*prev(maxHeap.end())));
}
}
}
}
树状数组+二分搜索:
#include<bits/stdc++.h>
#define lowbit(i) ((i) & (-i))
const int maxn = 100010;
using namespace std;
int c[maxn];//树状数组
stack<int> s;
void update(int x, int v) {
for (int i = x; i < maxn; i += lowbit(i)) c[i] += v;
}
int getsum(int x) {
int sum = 0;
for (int i = x; i >= 1; i -= lowbit(i)) sum += c[i];
return sum;
}
void PeekMedian() {
int left = 1, right = maxn, mid, k = (s.size() + 1) / 2;
while (left < right) {
mid = (left + right) / 2;
if (getsum(mid) >= k) right = mid;
else left = mid + 1;
}
cout << left << endl;
}
int main() {
int n, temp;
cin >> n;
string str;
for (int i = 0; i < n; i++) {
cin >> str;
if (str[1] == 'u') {
cin >> temp;
s.push(temp);
update(temp, 1);
}
else if (str[1] == 'o') {
if (!s.empty()) {
update(s.top(), -1);
cout << s.top() << endl;
s.pop();
}
else cout << "Invalid" << endl;
}
else {
if (!s.empty()) PeekMedian();
else cout << "Invalid" << endl;
}
}
return 0;
}
坑点:
无