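觉得模块化起来挺好看的，也不容易错 ^_^ (Wrapping the tree in a struct keeps the code tidy and less error-prone.)
BZOJ 3224 普通平衡树 (Ordinary Balanced Tree) — solved with a splay tree.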
#include <cstdio>
// Shorthands for the left / right child of whatever variable named `x`
// is in scope where the macro is expanded.
#define lf ch[x][0]
#define rg ch[x][1]
// Inclusive counting loop: i takes every value from j up to k.
#define rep(i,j,k) for ((i)=(j);(i)<=(k);++(i))
// Key used by the two sentinel nodes (as +inf / -inf); presumably chosen to
// lie outside the input key range -- confirm against the problem statement.
const int inf=2000000001;
// Capacity of the node pool.
const int N=100005;
using namespace std;
// Driver state used by main: operation count, loop index, opcode,
// operand, and last answer.
int n,i,od,x,o;
struct bst{
int root,pn,tp,cs,stk[N],sz[N],w[N],c[N],ch[N][2],fa[N];
int sd(int x) {return ch[fa[x]][1]==x;}
inline void updata(int x) { sz[x]=sz[lf]+sz[rg]+c[x]; }
inline int newnode(int W) {
int p;
if (tp>0) p=stk[tp--];
else p=++pn;
sz[p]=1; w[p]=W; c[p]=1; ch[p][0]=ch[p][1]=0; fa[p]=0;
return p;
}
void init() {
tp=pn=cs=0; newnode(-inf); newnode(inf); root=1;
sz[1]=2; ch[1][1]=2; sz[2]=1; fa[2]=1;
}
void rotate(int x)
{
int f=fa[x],gf=fa[f],gs=sd(f),xs=sd(x);
fa[f]=x; ch[gf][gs]=x; fa[x]=gf; ch[f][xs]=ch[x][xs^1];
fa[ch[x][xs^1]]=f; ch[x][xs^1]=f;
updata(f); updata(x);
}
void splay(int x,int y)
{
int sd1,sd2;
if (x==y) return;
while (fa[x]!=y)
{
if (fa[fa[x]]==y) rotate(x);
else {
sd1=sd(x); sd2=sd(fa[x]);
if (sd1^sd2) rotate(x); else rotate(fa[x]);
rotate(x);
}
}
if (!fa[x]) root=x;
}
void insert(int W)
{
int x=root,lst=0,lsd=0;
while (1) {
if (w[x]==W) { c[x]++; sz[x]++; splay(x,0); return; }
if (!x) { x=newnode(W); fa[x]=lst; ch[lst][lsd]=x; splay(x,0); return; }
if (w[x]<W) { lst=x; lsd=1; x=rg; }
else { lst=x; lsd=0; x=lf; }
}
}
int find_kth(int rt,int k) //ÎÞ·¨ÅжÏÕÒ²»µ½µÄÇé¿öŶ
{
int x=rt;
while (1)
{
if (sz[lf]<k && sz[lf]+c[x]>=k) {
splay(x,fa[rt]); return x;
}
else if (sz[lf]>=k) x=lf;
else {k-=sz[lf]+c[x]; x=rg;}
}
}
void del(int x)
{
int suc;
c[x]--; splay(x,0);
if (c[x]>0) return;
suc=find_kth(rg,1);
root=suc; ch[suc][0]=lf; fa[lf]=suc; fa[suc]=0; stk[++tp]=x;
updata(suc);
}
int find_l(int W)
{
int x=root,ll=0;
while (1) {
if (!x) break;
if (w[x]==W) {ll=x; break;}
if (w[x]<W) ll=x,x=rg; else x=lf;
}
splay(ll,0); return ll;
}
int find_r(int W)
{
int x=root,rr=0;
while (1) {
if (!x) break;
if (w[x]==W) { rr=x; break;}
if (w[x]<W) x=rg; else rr=x,x=lf;
}
splay(rr,0); return rr;
}
int find_rank(int W) {
int x=find_l(W);
if (w[x]==W) return sz[lf];
return c[x]+sz[lf];
}
int find_pre(int W) {
int x=find_l(W),y;
if (w[x]==W) y=find_kth(lf,sz[lf]);
else y=x;
return w[y];
}
int find_suc(int W) {
int x=find_r(W),y;
if (w[x]==W) y=find_kth(rg,1);
else y=x;
return w[y];
}
void dfs(int x)
{
printf("id=%d w=%d sz=%d lc=%d rc=%d fa=%d\n",x,w[x],sz[x],ch[x][0],ch[x][1],fa[x]);
if (ch[x][0]) dfs(ch[x][0]);
if (ch[x][1]) dfs(ch[x][1]);
}
void outit() {
cs++; printf("Case #%d:\n",cs);
dfs(root);
printf("\n");
}
}tr;
// Read the next decimal integer from stdin into ret.
// Non-digit characters before the number are skipped; a '-' is taken as the
// sign only when it immediately precedes the first digit. If end of input is
// reached before any digit, ret is set to 0 (instead of spinning forever).
void read(int &ret)
{
    int ch;      // int, not char: getchar() signals end of input with EOF
    int sgn=1;
    ret=0;
    for (ch=getchar(); ch!=EOF && (ch<'0' || ch>'9'); ch=getchar())
        sgn=(ch=='-') ? -1 : 1;
    for (; ch>='0' && ch<='9'; ch=getchar()) ret=ret*10+(ch-'0');
    ret*=sgn;
}
// Driver: reads n, then n lines "op value", and applies each to the tree:
//   1 insert value, 2 delete one copy of value, 3 rank of value,
//   4 value with rank `value`, 5 predecessor, 6 successor.
// Ops 3-6 print one line each; unknown ops are ignored.
int main()
{
    // freopen("bst.in","r",stdin);
    // freopen("bst.out","w",stdout);
    tr.init();
    read(n);
    rep(i,1,n)
    {
        read(od); read(x);
        switch (od) {
        case 1:
            tr.insert(x);
            break;
        case 2: {
            // find_l returns the node with the largest key <= x; delete only
            // on an exact match so removing an absent value is a no-op
            // instead of removing its predecessor.
            int p=tr.find_l(x);
            if (tr.w[p]==x) tr.del(p);
            break;
        }
        case 3:
            o=tr.find_rank(x); printf("%d\n",o);
            break;
        case 4:
            // The -inf sentinel occupies rank 1, hence x+1.
            o=tr.w[tr.find_kth(tr.root,x+1)]; printf("%d\n",o);
            break;
        case 5:
            o=tr.find_pre(x); printf("%d\n",o);
            break;
        case 6:
            o=tr.find_suc(x); printf("%d\n",o);
            break;
        default:
            break;
        }
        // tr.outit();
    }
    return 0;
}