#include <iostream>
#include <cstdio>
#include <set>
#include <stack>
#include <cstring>
#include <map>
#include <string>
using namespace std;
#define N 100001
// Incrementally maintains the median of the values currently on the stack
// using per-value counters instead of a sorted container.
//
// State (shared with the caller):
//   count[v]   - number of copies of value v currently on the stack.
//   mid        - the value occupying the median rank.
//   index[mid] - which copy of `mid` (1-based, 1..count[mid]) sits at the
//                median rank when the copies of `mid` are laid out in sorted
//                order. Only the entry for the current `mid` is meaningful.
// The transitions below keep the median at 1-based rank (size + 1) / 2:
// the rank stays put when the size goes odd -> even, and moves one step
// right when it goes even -> odd (and the reverse for pops). On top of that,
// inserting/removing a value below the median shifts the median's rank by one.
//
// Parameters:
//   mid     - in/out: the current median value.
//   d       - the value being pushed or popped.
//   push    - non-zero for a push, zero for a pop.
//   oldsize - stack size BEFORE this operation.
//   count, index - tables indexed directly by value; assumes 0 <= d < table
//                  length (N in main) — TODO confirm input range, since an
//                  out-of-range d indexes past the arrays.
void adjust(int &mid, int d, int push, int oldsize, int count[], int index[]) {
if(push) {
if(!oldsize) {
// First element: it is the median, as copy #1 of its value.
mid = d;
count[d] = 1;
index[d] = 1;
}
else {
if(d == mid) {
// New copy of the median value, conceptually appended after its run.
count[mid] ++;
if(oldsize%2) {
// odd -> even: median rank unchanged, same copy stays the median.
}
else {
// even -> odd: median advances one rank, to the next copy of mid
// (always exists because count[mid] was just incremented).
index[mid] ++;
}
}
else if(d < mid) {
// Inserted below the median: every element from the median up shifts
// one rank to the right.
count[d] ++;
if(oldsize%2) {
// odd -> even: rank unchanged, so the median becomes the element one
// position to the left of the old median.
if(index[mid] > 1) {
index[mid] --;
}
else {
// Old median was the first copy of mid: step down to the next
// smaller value that is present (d guarantees one exists) and take
// its last copy.
while(!count[--mid]) ;
index[mid] = count[mid];
}
}
else {
// even -> odd: rank +1 and the shift +1 cancel; same element.
}
}
else if(d > mid) {
// Inserted above the median: ranks up to the median are unaffected.
count[d] ++;
if(oldsize%2) {
// odd -> even: rank unchanged.
}
else {
// even -> odd: median moves one position to the right.
if(index[mid] < count[mid]) {
index[mid] ++;
}
else {
// Old median was the last copy of mid: step up to the next larger
// value that is present (d guarantees one exists), first copy.
while(!count[++mid]) ;
index[mid] = 1;
}
}
}
}
}
else {
if(oldsize > 1) {
if(mid == d) {
// Removing a copy of the median value (treated as the last copy of its run).
if(oldsize%2) {
// odd -> even: median rank moves one step left.
if(index[mid] > 1) {
index[mid] --;
count[mid] --;
}
else {
// Median was the first copy: the new median is the last copy of
// the next smaller present value.
count[mid] --;
while(!count[--mid]) ;
index[mid] = count[mid];
}
}
else {
// even -> odd: median rank unchanged.
if(index[mid] < count[mid]) {
// Median was not the last copy, so it keeps its position.
count[mid] --;
}
else {
// Median was the last copy and that slot disappears: the new
// median is the first copy of the next larger present value.
count[mid] --;
while(!count[++mid]) ;
index[mid] = 1;
}
}
}
else if(d < mid) {
// Removed below the median: the median shifts one rank left.
count[d] --;
if(oldsize%2) {
// odd -> even: rank -1 and the shift -1 cancel; same element.
}
else {
// even -> odd: rank unchanged, so move one position to the right.
if(index[mid] < count[mid])
index[mid] ++;
else {
while(!count[++mid]) ;
index[mid] = 1;
}
}
}
else if(d > mid) {
// Removed above the median: ranks up to the median are unaffected.
count[d] --;
if(oldsize%2) {
// odd -> even: median rank moves one step left.
if(index[mid] > 1)
index[mid] --;
else {
while(!count[--mid]) ;
index[mid] = count[mid];
}
}
else {
// even -> odd: rank unchanged.
}
}
}
else {
// Popping the last element: only the counter is updated. `mid` is left
// stale, so callers must not read it while the stack is empty.
count[d] --;
}
}
}
// Processes the stack commands using the count-table median tracker:
//   "Pop"    - print and remove the top element, or "Invalid" if empty;
//   "Push x" - push x (x must satisfy 0 <= x < N, because it indexes the
//              count/index tables directly);
//   anything else (expected: "PeekMedian") - print the median, or "Invalid".
int main(int argc, char **argv) {
    int n;
    if (!(cin >> n)) {
        return 0;
    }
    char cmd[20];
    stack<int> st;
    // Two value-indexed tables of N ints (~800 KB together): give them static
    // storage instead of placing them on the call stack.
    static int count[N] = {0};
    static int index[N] = {0};
    int mid = 0;
    for (int i = 0; i < n; i++) {
        // Bound the read to the buffer size and stop cleanly at end of input
        // instead of re-processing a stale command.
        if (scanf("%19s", cmd) != 1) {
            break;
        }
        if (!strcmp(cmd, "Pop")) {
            if (st.size()) {
                int d = st.top();
                printf("%d\n", d);
                // adjust() expects the stack size *before* the pop.
                adjust(mid, d, 0, st.size(), count, index);
                st.pop();
            } else {
                printf("Invalid\n");
            }
        } else if (!strcmp(cmd, "Push")) {
            int d;
            if (scanf("%d", &d) != 1) {
                break;
            }
            // adjust() expects the stack size *before* the push.
            adjust(mid, d, 1, st.size(), count, index);
            st.push(d);
        } else {
            // Any other command is treated as "PeekMedian".
            if (st.size()) {
                printf("%d\n", mid);
            } else {
                printf("Invalid\n");
            }
        }
    }
    return 0;
}
下面是按照链接中提供的优化算法（用两个 multiset 在中位数两侧分别保存较小和较大的元素）实现的代码：算法本质相同，但实现起来方便许多。
#include <iostream>
#include <cstdio>
#include <set>
#include <cstring>
#include <stack>
using namespace std;
// Removes and returns the smallest element of a non-empty multiset.
static int take_min(std::multiset<int> &s) {
    std::multiset<int>::iterator it = s.begin();
    int v = *it;
    s.erase(it);
    return v;
}

// Removes and returns the largest element of a non-empty multiset.
static int take_max(std::multiset<int> &s) {
    std::multiset<int>::iterator it = s.end();
    --it;
    int v = *it;
    s.erase(it);
    return v;
}

// Maintains the median of the values currently on the stack with two
// multisets around a separately stored median value `mid`:
//   smaller - values ranked before the median (all <= mid),
//   bigger  - values ranked after the median (all >= mid).
// While exist is true, bigger.size() is either smaller.size() (odd total) or
// smaller.size() + 1 (even total), so `mid` is the element at 1-based rank
// (total + 1) / 2.
//
// Parameters:
//   cmd   - 0 removes n (pop); any other value inserts n (push).
//   n     - the value being popped or pushed; for a pop it must currently be
//           stored in the structure.
//   mid   - in/out: current median; only meaningful while exist is true.
//   exist - in/out: whether the structure holds at least one element.
// Note: the multisets are function-local statics, so state persists across
// calls and is shared by all callers; there is no reset.
void adjust(int cmd, int n, int &mid, bool &exist) {
    static std::multiset<int> smaller, bigger;
    if (cmd == 0) {
        if (n < mid) {
            // Strictly below the median, so a copy of n must be in `smaller`.
            smaller.erase(smaller.find(n));
            // size() is unsigned: compare sizes directly rather than
            // subtracting them, which would wrap around.
            if (bigger.size() >= smaller.size() + 2) {
                smaller.insert(mid);
                mid = take_min(bigger);
            }
        } else if (n > mid) {
            // Strictly above the median, so a copy of n must be in `bigger`.
            bigger.erase(bigger.find(n));
            if (smaller.size() > bigger.size()) {
                bigger.insert(mid);
                mid = take_max(smaller);
            }
        } else {
            // n == mid: drop the median slot itself and promote a neighbour
            // so that the size invariant holds for the new total.
            if (smaller.empty() && bigger.empty()) {
                exist = false;
            } else if (smaller.size() == bigger.size()) {
                // Odd -> even total: the median becomes the largest smaller value.
                mid = take_max(smaller);
            } else {
                // Even -> odd total: the median becomes the smallest bigger value.
                mid = take_min(bigger);
            }
        }
    } else {
        if (!exist) {
            // First element: both multisets are empty here.
            mid = n;
            exist = true;
        } else if (n > mid) {
            bigger.insert(n);
            if (bigger.size() >= smaller.size() + 2) {
                smaller.insert(mid);
                mid = take_min(bigger);
            }
        } else if (n < mid) {
            smaller.insert(n);
            if (smaller.size() > bigger.size()) {
                bigger.insert(mid);
                mid = take_max(smaller);
            }
        } else if (smaller.size() < bigger.size()) {
            // Equal to the median: put it on whichever side restores balance.
            smaller.insert(n);
        } else {
            bigger.insert(n);
        }
    }
}
// Processes the stack commands using the multiset-based median tracker:
//   "Pop"        - print and remove the top element, or "Invalid" if empty;
//   "PeekMedian" - print the median, or "Invalid" if empty;
//   anything else (expected: "Push x") - push x.
int main()
{
    int n;
    if (!(cin >> n)) {
        return 0;
    }
    stack<int> s;
    int mid = 0;          // only read while exist is true
    bool exist = false;   // true while the stack is non-empty
    char cmd[20];
    for (int i = 0; i < n; i++) {
        // Bound the read to the buffer size and stop cleanly at end of input
        // instead of re-processing a stale command.
        if (scanf("%19s", cmd) != 1) {
            break;
        }
        if (!strcmp(cmd, "Pop")) {
            if (exist) {
                int top = s.top();
                printf("%d\n", top);
                adjust(0, top, mid, exist);
                s.pop();
            } else {
                printf("Invalid\n");
            }
        } else if (!strcmp(cmd, "PeekMedian")) {
            if (exist) {
                printf("%d\n", mid);
            } else {
                printf("Invalid\n");
            }
        } else {
            int d;
            if (scanf("%d", &d) != 1) {
                break;
            }
            s.push(d);
            adjust(1, d, mid, exist);
        }
    }
    return 0;
}