1. 首先看题意很容易就得出暴力解法, 即出/入栈正常模拟, 找中值就通过对当前栈内的数进行排序, 再输出中值即可, 复杂度O(n^n^logn)。
代码如下:
#include <bits/stdc++.h>
using namespace std;
#define endl "\n"
#define debug cout<<"debug"<<endl
mt19937 rd(time(0));
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
const double eps = 1e-8;
const double PI = 3.14159265358979323;
const int N = 2e5+10, M = 2*N, mod = 1e9+7;
const int INF = 0x3f3f3f3f;
int n;
stack<int> stk;
void solve()
{
cin>>n;
vector<int> v;
while( n -- )
{
string s; cin>>s;
if(s == "Pop")
{
if(stk.size())
{
cout<<stk.top()<<endl; stk.pop();
v.erase(v.end()-1);
}else cout<<"Invalid"<<endl;
}else if(s == "PeekMedian")
{
if(stk.size())
{
vector<int> tmp = v;
sort(tmp.begin(), tmp.end());
cout<<tmp[(tmp.size()-1)/2]<<endl;
}else cout<<"Invalid"<<endl;
}else{
int x; cin>>x;
stk.push(x);
v.push_back(x);
}
}
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
int T;
T = 1;
// cin>>T;
while(T -- )
{
solve();
}
return 0;
}
很明显, 这样的做法铁T, 但是能骗到17分, 还可以
2. 优化, 考虑一个cnt数组, cnt[x]表示栈内x出现的次数, 对从1开始的cnt数组, 依次求和, 当和sum>=中位即(n+1)/2时, 当前位就是中值, 时间复杂度 O(n^n)
#include <bits/stdc++.h>
using namespace std;
#define endl "\n"
#define debug cout<<"debug"<<endl
mt19937 rd(time(0));
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
const double eps = 1e-8;
const double PI = 3.14159265358979323;
const int N = 2e5+10, M = 2*N, mod = 1e9+7;
const int INF = 0x3f3f3f3f;
int n;
stack<int> stk;
int cnt[N];
void solve()
{
cin>>n;
vector<int> v;
while( n -- )
{
string s; cin>>s;
if(s == "Pop")
{
if(stk.size())
{
cnt[stk.top()]--;
cout<<stk.top()<<endl; stk.pop();
}else cout<<"Invalid"<<endl;
}else if(s == "PeekMedian")
{
if(stk.size())
{
ll sum = 0;
for(int i = 1; i<=1e5; i++)
{
sum += cnt[i];
if(sum>=(stk.size()+1)/2)
{
cout<<i<<endl;
break;
}
}
}else cout<<"Invalid"<<endl;
}else{
int x; cin>>x;
stk.push(x);
cnt[x]++;
}
}
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
int T;
T = 1;
// cin>>T;
while(T -- )
{
solve();
}
return 0;
}
漂亮, 又多骗了5分hhh
3. 最终优化, 考虑二分, 我们要找的是一个数, 这个数和前面所有数的出现次数之和, 来判断这个次数和是不是到了中位, 这个过程是可以二分的, 但问题是如何快速的获取前面所有数出现的次数之和, 我一开始想的是树状数组, 但是脑子抽了不太会写, 选择用线段树的解法, 线段树维护区间数字出现的次数和, 每次入栈就对应区间+1反之-1, 查询中值时, 就通过查询线段树来快速的获取前面所有数出现的次数和, 时间复杂度O(n^logn^logn)
代码如下:
#include <bits/stdc++.h>
using namespace std;
#define endl "\n"
#define debug cout<<"debug"<<endl
mt19937 rd(time(0));
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
const double eps = 1e-8;
const double PI = 3.14159265358979323;
const int N = 2e5+10, M = 2*N, mod = 1e9+7;
const int INF = 0x3f3f3f3f;
int n;
stack<int> stk;
int cnt[N];
struct Node{
int l, r, sum;
}tr[N<<2];
void pushup(int u)
{
tr[u].sum = tr[u<<1].sum + tr[u<<1|1].sum;
}
void build(int u, int l, int r)
{
tr[u] = {l, r};
if(l == r) return;
int mid = l + r >> 1;
build(u<<1, l, mid), build(u<<1|1, mid+1, r);
}
int query(int u, int l, int r)
{
if(tr[u].l>=l && tr[u].r<=r) return tr[u].sum;
int mid = tr[u].l + tr[u].r >> 1;
int res = 0;
if(l<=mid) res = query(u<<1, l, r);
if(r>mid) res += query(u<<1|1, l, r);
return res;
}
void modify(int u, int x, int k)
{
if(tr[u].l == x && tr[u].r == x)
{
if(k == 1) tr[u].sum++;
else tr[u].sum = max(0, tr[u].sum-1);
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if(x<=mid) modify(u<<1, x, k);
else modify(u<<1|1, x, k);
pushup(u);
}
void solve()
{
cin>>n;
vector<int> v;
build(1, 1, 1e5);
while( n -- )
{
string s; cin>>s;
if(s == "Pop")
{
if(stk.size())
{
modify(1, stk.top(), 0);
cout<<stk.top()<<endl; stk.pop();
}else cout<<"Invalid"<<endl;
}else if(s == "PeekMedian")
{
if(stk.size())
{
int l = 1, r = 1e5;
int x = (stk.size()+1)/2;
while(l<r)
{
int mid = l+r >> 1;
if(query(1, 1, mid)>=x) r = mid;
else l = mid + 1;
}
cout<<l<<endl;
}else cout<<"Invalid"<<endl;
}else{
int x; cin>>x;
stk.push(x);
modify(1, x, 1);
}
}
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
int T;
T = 1;
// cin>>T;
while(T -- )
{
solve();
}
return 0;
}
然后就ac啦, 但显然还有更好写的做法
下面给出一开始没想到的树状数组优化, 思维同线段树但实现起来更简单:
#include <bits/stdc++.h>
using namespace std;
#define endl "\n"
#define debug cout<<"debug"<<endl
mt19937 rd(time(0));
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
const double eps = 1e-8;
const double PI = 3.14159265358979323;
const int N = 2e5+10, M = 2*N, mod = 1e9+7;
const int INF = 0x3f3f3f3f;
int n;
stack<int> stk;
int tr[N];
int lowbit(int x)
{
return x & -x;
}
void add(int x, int c)
{
for( ; x<=1e5; x += lowbit(x)) tr[x] += c;
}
int ask(int x)
{
int res = 0;
for( ; x; x -= lowbit(x)) res += tr[x];
return res;
}
void solve()
{
cin>>n;
vector<int> v;
while( n -- )
{
string s; cin>>s;
if(s == "Pop")
{
if(stk.size())
{
add(stk.top(), -1);
cout<<stk.top()<<endl; stk.pop();
}else cout<<"Invalid"<<endl;
}else if(s == "PeekMedian")
{
if(stk.size())
{
int l = 1, r = 1e5;
int x = (stk.size()+1)/2;
while(l<r)
{
int mid = l+r >> 1;
if(ask(mid)>=x) r = mid;
else l = mid + 1;
}
cout<<l<<endl;
}else cout<<"Invalid"<<endl;
}else{
int x; cin>>x;
stk.push(x);
add(x, 1);
}
}
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
int T;
T = 1;
// cin>>T;
while(T -- )
{
solve();
}
return 0;
}