打了个板子,之前没初始化wa了两个点
也不知道为什么会错。。。。。
替罪羊树的主要思想就是将不平衡的树压成一个序列,然后暴力重构成一颗平衡的树.
这里的平衡指的是:对于某个 0.5<=alpha<=1 满足 size( lson(x) )<=alpha*size(x) 并且 size( rson(x) )<=alpha*size(x),即这个节点的两棵子树的 size 都不超过以该节点为根的子树的 size ,那么就称这个子树(或节点)是平衡的, alpha 最好不要选 0.5 ,容易T飞,一般选 0.75 就挺好的.
至于复杂度,虽说是重构,但复杂度并不高,压扁和重建都是递归操作,也就是像线段树一样的 log 级别,由于平衡的限制,插入,删除,即查询等操作也会控制在一个较低的级别,均摊下来替罪羊树的总复杂度是 O(logn) 的.
#include<cstdio>
#include<iostream>
#define ll long long
#define db double
const ll inf=1e9+7;
using namespace std;
const db al=0.75;
ll cur[110000],sum,root=0,tot=0,n;
struct ScapeGoatTree{
ll ch[110000][3];
ll key[110000],size[190099],f[110000];
void init(){
tot=2,root=1;
key[1]=-inf,size[1]=2,ch[1][1]=2;
key[2]=inf,size[2]=1,f[2]=1;
}
ll get(ll x){//查那个儿子
return ch[f[x]][1]==x;
}
bool blance(ll x){//平衡标准
return (db)size[x]*al>=size[ch[x][0]]&&(db)size[x]*al>=(db)size[ch[x][1]];
}
void dfs(ll x){//记下拍扁后序列
if(!x) return ;
dfs(ch[x][0]);
cur[++sum]=x;
dfs(ch[x][1]);
}
ll built(ll l,ll r){//建新树
if(l==r) {
size[cur[l]]=1;return l;
}
ll mid=(l+r)>>1,tmp=cur[mid];
f[ch[tmp][0]=built(l,mid)]=tmp;
f[ch[tmp][1]=built(mid+1,r)]=tmp;
size[tmp]=size[ch[tmp][0]]+size[ch[tmp][1]]+1;
return tmp;
}
void rebuilt(ll x){//重建树
sum=0;dfs(x);
ll fa=f[x],opt=get(x);
ll tmp=built(1,sum);
f[tmp]=fa;
ch[fa][opt]=tmp;
if(root==x) root=tmp;
}
void insert(ll x){//插入
if(!root){f[++tot]=ch[tot][0]=ch[tot][1]=0;size[tot]=1;key[tot]=x;root=tot;return;}
ll tmp=root,fa=0;
key[++tot]=x;size[tot]=1;
while(tmp){
size[tmp]++;
fa=tmp;
tmp=ch[tmp][key[tmp]<x];
if(!tmp){
f[tot]=fa;ch[fa][key[fa]<x]=tot;
break;
}
}
ll flag=0;
for(;tmp;tmp=f[tmp]) if(!blance(tmp)) flag=tmp;
if(flag) rebuilt(flag);
}
ll find(ll x){//查排名
ll tmp=root,ans=0;
while(tmp){
if(key[tmp]>=x) tmp=ch[tmp][0];
else ans+=size[ch[tmp][0]]+1,tmp=ch[tmp][1];
}
return ans;
}
ll findx(ll x){//差排名为x的数
ll tmp=root;
while(tmp){
if(size[ch[tmp][0]]==x-1) return tmp;
else if(size[ch[tmp][0]]>=x) tmp=ch[tmp][0];
else x-=size[ch[tmp][0]]+1,tmp=ch[tmp][1];
}
return tmp;
}
ll pre(ll x){//查前驱
ll tmp=root;ll ans=-inf;
while(tmp){
if(key[tmp]<x) ans=max(key[tmp],ans),tmp=ch[tmp][1];
else tmp=ch[tmp][0];
}
return ans;
}
ll next(ll x){//查后继
ll tmp=root;ll ans=inf;
while(tmp){
if(key[tmp]>x)ans=min(key[tmp],ans),tmp=ch[tmp][0];
else tmp=ch[tmp][1];
}
return ans;
}
ll getnum(ll x){//找编号
ll tmp=root;
while(tmp){
if(key[tmp]==x) return tmp;
else tmp=ch[tmp][key[tmp]<x];
}
}
void del(ll x){//删除
x=getnum(x);
if(ch[x][0]&&ch[x][1]){
ll tmp=ch[x][0];
while(ch[tmp][1]) tmp=ch[tmp][1];
key[x]=key[tmp];x=tmp;
}
ll son=ch[x][0]?ch[x][0]:ch[x][1];
ll opt=get(x);
ch[f[x]][opt]=son;
f[son]=f[x];
for(;x;x=f[x]) size[x]--;
if(x==root)root=son;
}
}s;
int main(){
scanf("%lld",&n);
s.init();
while(n--){
ll x;ll opt;
scanf("%lld%lld",&opt,&x);
if(opt==1) s.insert(x);
if(opt==2) s.del(x);
if(opt==3) printf("%lld\n",s.find(x));
if(opt==4) printf("%lld\n",s.key[s.findx(x+1)]);
if(opt==5) printf("%lld\n",s.pre(x));
if(opt==6) printf("%lld\n",s.next(x));
}
}
上面那个代码错了,当普通平衡树用吧
下面那个应该对了
#include<cstdio>
#include<iostream>
#define ll long long
#define db double
const ll inf=1e9+7;
using namespace std;
const db al=0.75;
ll cur[110000],sum,root=0,tot=0,n;
struct ScapeGoatTree{
ll ch[110000][3];
ll key[110000],size[190099],f[110000];
void init(){
tot=2,root=1;
key[1]=-inf,size[1]=2,ch[1][1]=2;
key[2]=inf,size[2]=1,f[2]=1;
}
ll get(ll x){
return ch[f[x]][1]==x;
}
bool blance(ll x){
return (db)size[x]*al>=(db)size[ch[x][0]]&&(db)size[x]*al>=(db)size[ch[x][1]];
}
void dfs(ll x){
if(!x) return ;
dfs(ch[x][0]);
cur[++sum]=x;
dfs(ch[x][1]);
}
ll built(ll l,ll r){
if(l>r) return 0;
ll mid=(l+r)>>1,tmp=cur[mid];
f[ch[tmp][0]=built(l,mid-1)]=tmp;
f[ch[tmp][1]=built(mid+1,r)]=tmp;
size[tmp]=size[ch[tmp][0]]+size[ch[tmp][1]]+1;
return tmp;
}
void rebuilt(ll x){
sum=0;dfs(x);
ll fa=f[x],opt=get(x);
ll tmp=built(1,sum);
f[tmp]=fa;
ch[fa][opt]=tmp;
if(root==x) root=tmp;
}
void insert(ll x){
if(!root){f[++tot]=ch[tot][0]=ch[tot][1]=0;size[tot]=1;key[tot]=x;root=tot;return;}
ll tmp=root,fa=0;
key[++tot]=x;size[tot]=1;
while(tmp){
size[tmp]++;
fa=tmp;
tmp=ch[tmp][key[tmp]<x];
if(!tmp){
f[tot]=fa;ch[fa][key[fa]<x]=tot;
tmp=tot;
break;
}
}
ll flag=0;
for(;tmp;tmp=f[tmp]) if(!blance(tmp)) flag=tmp;
if(flag) rebuilt(flag);
}
ll find(ll x){
ll tmp=root,ans=0;
while(tmp){
if(key[tmp]>=x) tmp=ch[tmp][0];
else ans+=size[ch[tmp][0]]+1,tmp=ch[tmp][1];
}
return ans;
}
ll findx(ll x){
ll tmp=root;
while(tmp){
if(size[ch[tmp][0]]==x-1) return tmp;
else if(size[ch[tmp][0]]>=x) tmp=ch[tmp][0];
else x-=size[ch[tmp][0]]+1,tmp=ch[tmp][1];
}
return tmp;
}
ll pre(ll x){
ll tmp=root;ll ans=-inf;
while(tmp){
if(key[tmp]<x) ans=max(key[tmp],ans),tmp=ch[tmp][1];
else tmp=ch[tmp][0];
}
return ans;
}
ll next(ll x){
ll tmp=root;ll ans=inf;
while(tmp){
if(key[tmp]>x)ans=min(key[tmp],ans),tmp=ch[tmp][0];
else tmp=ch[tmp][1];
}
return ans;
}
ll getnum(ll x){
ll tmp=root;
while(tmp){
if(key[tmp]==x) return tmp;
else tmp=ch[tmp][key[tmp]<x];
}
}
void del(ll x){
x=getnum(x);
if(ch[x][0]&&ch[x][1]){
ll tmp=ch[x][0];
while(ch[tmp][1]) tmp=ch[tmp][1];
key[x]=key[tmp];x=tmp;
}
ll son=ch[x][0]?ch[x][0]:ch[x][1];
ll opt=get(x);
ch[f[x]][opt]=son;
f[son]=f[x];
for(int i=x;i;i=f[i]) size[i]--;
if(x==root)root=son;
}
}s;
int main(){
scanf("%lld",&n);
s.init();
while(n--){
ll x;ll opt;
scanf("%lld%lld",&opt,&x);
if(opt==1) s.insert(x);
if(opt==2) s.del(x);
if(opt==3) printf("%lld\n",s.find(x));
if(opt==4) printf("%lld\n",s.key[s.findx(x+1)]);
if(opt==5) printf("%lld\n",s.pre(x));
if(opt==6) printf("%lld\n",s.next(x));
}
}