Splay模板
模板对应题目:BZOJ3224
#include <iostream>
#include <cstdio>
#include <algorithm>
const int MAX=100010;
using namespace std;
struct Splay {
int ch[MAX][2],f[MAX],root,cnt;
int s[MAX],c[MAX],a[MAX];
int grow (int k,int fa) {
int x=++cnt;
a[x]=k,s[x]=c[x]=1;
ch[x][0]=ch[x][1]=0,f[x]=fa;
return x;
}
void update (int x) {
s[x]=s[ch[x][0]]+s[ch[x][1]]+c[x];
}
void rotate (int x) {
int y=f[x],opt;
if (ch[f[x]][0]==x) opt=0;
else opt=1;
ch[y][opt]=ch[x][!opt];
if (ch[x][!opt]) f[ch[x][!opt]]=y;
f[x]=f[y];
if (root==y) root=x;
else if (ch[f[y]][0]==y) ch[f[y]][0]=x;
else ch[f[y]][1]=x;
f[y]=x,ch[x][!opt]=y;
update(y),update(x);
}
void splay (int x,int to=0) {
while (f[x]!=to) {
if (f[f[x]]==to) rotate(x);
else if ((ch[f[f[x]]][0]==f[x])
==(ch[f[x]][0]==x))
rotate(f[x]),rotate(x);
else rotate(x),rotate(x);
}
}
void insert (int k) {
if (root==0) root=grow(k,0);
else {
int x=root;
while (x) {
if (k==a[x]) { c[x]++;break; }
else {
int &y=(k<a[x]?ch[x][0]:ch[x][1]);
if (!y)
{ x=y=grow(k,x);break; }
else x=y;
}
}
splay(x);
}
}
void erase (int k) {
int x=root;
while (x) {
if (k==a[x]) { c[x]--;break; }
else x=ch[x][!(k<a[x])];
}
splay(x);
if (!c[x]) {
int p=ch[x][0];
while (ch[p][1]) p=ch[p][1];
if (p) {
splay(p,root);
ch[p][1]=ch[root][1];
f[f[ch[root][1]]=p]=0;
update(root=p);
}
else {
if (ch[root][1]) f[ch[root][1]]=0,root=ch[root][1];
else root=cnt=0;//注意:如果写了回收结点,那么这里cnt一定不能置0
}
}
}
int find_by_rank (int k) {
int x=root;
while (x) {
if (k>=s[ch[x][0]]+1&&k<=s[ch[x][0]]+c[x])
return a[x];
else if (k<s[ch[x][0]]+1) x=ch[x][0];
else k-=s[ch[x][0]]+c[x],x=ch[x][1];
}
}
int find_by_val (int k) {
int x=root,ans=0;
while (x) {
if (k==a[x]) return ans+s[ch[x][0]]+1;
else if (k<a[x]) x=ch[x][0];
else ans+=s[ch[x][0]]+c[x],x=ch[x][1];
}
return ans;
}
int prev (int k) {
int x=root,ans;
while (x) {
if (k>a[x]) ans=a[x],x=ch[x][1];
else x=ch[x][0];
}
return ans;
}
int next (int k) {
int x=root,ans;
while (x) {
if (k<a[x]) ans=a[x],x=ch[x][0];
else x=ch[x][1];
}
return ans;
}
Splay () { root=cnt=0; }
}Sp;
int n,i,opt,x;
int read () {
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9') {
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9') {
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
int main () {
n=read();
for (int i=1;i<=n;i++) {
opt=read();x=read();
if (opt==1) Sp.insert(x);
if (opt==2) Sp.erase(x);
if (opt==3) printf("%d\n",Sp.find_by_val(x));
if (opt==4) printf("%d\n",Sp.find_by_rank(x));
if (opt==5) printf("%d\n",Sp.prev(x));
if (opt==6) printf("%d\n",Sp.next(x));
}
return 0;
}