替罪羊树 模版题
原题地址:http://www.lydsy.com/JudgeOnline/problem.php?id=3224
题意:
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1. 插入x数
2. 删除x数(若有多个相同的数,因只删除一个)
3. 查询x数的排名(若有多个相同的数,因输出最小的排名)
4. 查询排名为x的数
5. 求x的前驱(前驱定义为小于x,且最大的数)
6. 求x的后继(后继定义为大于x,且最小的数)
数据范围
1.n的数据范围:n<=100000
2.每个数的数据范围:[-2e9,2e9]
题解:
学习一发替罪羊树。
opt1:直接插入,找插入点到root的链上高度最高的不平衡点重构。
opt2:把他赋为它前驱的值,之后删前驱,前驱最多只有一个儿子就很好删。
opt3~opt6和splay差不多。
代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=100005;
const double Alpha=0.7;
const int inf=2147183600;
struct node
{
int fa,size,ch[2],val;
}tr[N];
int n,root,tail=0,V[N],top=0;
void init()
{
tail=2; root=1;
tr[1].val=-inf; tr[1].size=2;
tr[2].val=inf; tr[2].size=1;
tr[1].ch[1]=2; tr[1].fa=0; tr[2].fa=1;
}
bool balance(int x)
{
return (double)tr[x].size*Alpha>=(double)tr[tr[x].ch[0]].size
&&(double)tr[x].size*Alpha>=(double)tr[tr[x].ch[1]].size;
}
void getall(int x)
{
if(tr[x].ch[0]) getall(tr[x].ch[0]);
V[++top]=x;
if(tr[x].ch[1]) getall(tr[x].ch[1]);
}
int build(int lf,int rg)
{
if(lf>rg)return 0;
int mid=(lf+rg)>>1;
int nd=V[mid];
tr[nd].ch[0]=build(lf,mid-1); if(tr[nd].ch[0]) tr[tr[nd].ch[0]].fa=nd;
tr[nd].ch[1]=build(mid+1,rg); if(tr[nd].ch[1]) tr[tr[nd].ch[1]].fa=nd;
tr[nd].size=tr[tr[nd].ch[0]].size+tr[tr[nd].ch[1]].size+1;
return nd;
}
void rebuild(int x)
{
int fa=tr[x].fa; top=0; int opt;
if(fa) opt=(tr[fa].ch[1]==x);
getall(x);
int nd=build(1,top);
if(x==root) root=nd;
else tr[fa].ch[opt]=nd;
tr[nd].fa=fa;
}
void insert(int val)
{
int tmp=root; int f=0;
int nd=++tail; tr[nd].val=val; tr[nd].size=1;
while(1)
{
tr[tmp].size++;
int opt=(val<=tr[tmp].val)? 0:1;
f=tmp; tmp=tr[tmp].ch[opt];
if(!tmp)
{
tr[nd].fa=f; tr[f].ch[opt]=nd; break;
}
}
int pos=0;
for(int i=nd;i;i=tr[i].fa) {if(!balance(i)) pos=i;}
if(pos) rebuild(pos);
}
int getpos(int val)
{
int tmp=root;
while(1)
{
if(tr[tmp].val==val) return tmp;
else tmp=tr[tmp].ch[(val>tr[tmp].val)];
}
return 0;
}
void del(int x)
{
int tmp=x;
if(tr[x].ch[0]&&tr[x].ch[1])
{
tmp=tr[x].ch[0];
while(tr[tmp].ch[1]) tmp=tr[tmp].ch[1];
tr[x].val=tr[tmp].val;
x=tmp;
}
int fa=tr[tmp].fa;
int son=tr[tmp].ch[0]? tr[tmp].ch[0]:tr[tmp].ch[1]; int opt= tr[tr[tmp].fa].ch[0]==tmp? 0:1;
if(son) tr[son].fa=fa; if(fa) tr[fa].ch[opt]=son;
while(fa) {tr[fa].size--;fa=tr[fa].fa;}
if(tmp==root) root=son;
}
int getrank(int val)
{
int tmp=root; int ret=0;
while(tmp)
{
if(tr[tmp].val<val) {ret+=tr[tr[tmp].ch[0]].size+1; tmp=tr[tmp].ch[1];}
else tmp=tr[tmp].ch[0];
}
return ret;
}
int getkth(int k)
{
int tmp=root;
while(1)
{
if(tr[tr[tmp].ch[0]].size+1==k) return tr[tmp].val;
else if(tr[tr[tmp].ch[0]].size+1>k) tmp=tr[tmp].ch[0];
else {k-=tr[tr[tmp].ch[0]].size+1; tmp=tr[tmp].ch[1];}
}
}
int getpre(int x)
{
int ans=-inf; int tmp=root;
while(tmp)
{
if(tr[tmp].val>=x) tmp=tr[tmp].ch[0];
else {ans=max(ans,tr[tmp].val); tmp=tr[tmp].ch[1]; }
}
return ans;
}
int getnxt(int x)
{
int ans=inf; int tmp=root;
while(tmp)
{
if(tr[tmp].val<=x) tmp=tr[tmp].ch[1];
else {ans=min(ans,tr[tmp].val); tmp=tr[tmp].ch[0]; }
}
return ans;
}
int main()
{
scanf("%d",&n);
init();
while(n--)
{
int opt,x;
scanf("%d%d",&opt,&x);
if(opt==1) insert(x);
else if(opt==2) del(getpos(x));
else if(opt==3) printf("%d\n",getrank(x));
else if(opt==4) printf("%d\n",getkth(x+1));
else if(opt==5) printf("%d\n",getpre(x));
else if(opt==6) printf("%d\n",getnxt(x));
}
return 0;
}