数组实现的SB树(而且竟然没有结构体)
功能就和普通的二叉查找树一样:
insert: 插入元素
erase: 删除元素
pre: 查询一个元素的前驱
suc: 查询一个元素的后继
select: 返回第几小的元素的值
rank: 返回一个值的位置
除了maintain全部非递归实现,并且有assert(似乎并没有什么用)
maintain写得不是很高效,
并且缺少存储平衡树中当前元素个数的变量(size表示总共插入过的元素个数),不过这并不难补上
root: 树的根(并不一定是1)
这其实是洛谷的模板题
#include<stdio.h>
#include<assert.h>
/* Node capacity; index 0 is the null sentinel, real nodes use 1..N-1. */
#define N 100005
int root;   /* index of the current root, 0 while the tree is empty */
int val[N]; /* value stored at each node */
int lch[N]; /* left child index, 0 = none */
int rch[N]; /* right child index, 0 = none */
int w[N];   /* subtree size; w[0] must stay 0 because 0 stands for "no child" */
int size;   /* nodes ever allocated (= last used index), not the live count */
/* Right rotation of the subtree hanging in slot t: its left child becomes
   the new subtree root. The new root inherits the whole subtree size and
   the old root's size is recomputed from its (new) children. */
void rrot(int &t) {
    int oldRoot = t;
    int pivot = lch[oldRoot];
    lch[oldRoot] = rch[pivot];
    rch[pivot] = oldRoot;
    w[pivot] = w[oldRoot];
    w[oldRoot] = 1 + w[lch[oldRoot]] + w[rch[oldRoot]];
    t = pivot;
}
/* Left rotation of the subtree hanging in slot t: its right child becomes
   the new subtree root. The new root inherits the whole subtree size and
   the old root's size is recomputed from its (new) children. */
void lrot(int &t) {
    int oldRoot = t;
    int pivot = rch[oldRoot];
    rch[oldRoot] = lch[pivot];
    lch[pivot] = oldRoot;
    w[pivot] = w[oldRoot];
    w[oldRoot] = 1 + w[lch[oldRoot]] + w[rch[oldRoot]];
    t = pivot;
}
void mant(int &t) {
int fl=0,fr=0;
if(w[lch[lch[t]]]>w[rch[t]]) fr=1, rrot(t);
else if(w[rch[lch[t]]]>w[rch[t]]) fl=fr=1, lrot(lch[t]), rrot(t);
else if(w[rch[rch[t]]]>w[lch[t]]) fl=1, lrot(t);
else if(w[lch[rch[t]]]>w[lch[t]]) fl=fr=1, rrot(rch[t]), lrot(t);
if(fl) mant(lch[t]); if(fr) mant(rch[t]); if(fl||fr) mant(t);
}
/* Return the k-th smallest stored value (1-based).
   The walk stops only at a real node, so an out-of-range k (including
   k == element count + 1 and any k on an empty tree) reaches the null
   sentinel and trips the assertion instead of quietly returning val[0]. */
int select(int k) {
    int t = root;
    while (t) {
        int left = w[lch[t]];
        if (k <= left) {
            t = lch[t];
        } else if (k == left + 1) {
            return val[t];
        } else {
            k -= left + 1;
            t = rch[t];
        }
    }
    assert(0 && "select: k out of range");
    return val[0];
}
int rank(int v) {
int t=root,ret=0,flag=0;
while(t) {
if(v==val[t]) return ret+w[lch[t]]+1;
else {
if(v<val[t]) t=lch[t];
else ret+=w[lch[t]]+1, t=rch[t];
}
}
assert(0);
}
int pre(int v) {
int t=root,ret=0;
while(1) {
if(v<=val[t]) {
if(lch[t]) t=lch[t];
else break;
}
else {
ret=t;
if(rch[t]) t=rch[t];
else break;
}
}
assert(ret);
return val[ret];
}
int suc(int v) {
int t=root, ret=0;
while(1) {
if(v>=val[t]) {
if(rch[t]) t=rch[t];
else break;
}
else {
ret=t;
if(lch[t]) t=lch[t];
else break;
}
}
assert(ret);
return val[ret];
}
void insert(int v) {
if(!root) {root=++size; val[size]=v; w[size]=1; return; }
int t=root,*x;
while(w[t]++) {
if(v<val[t]) {
if(lch[t]) t=lch[t];
else {x=&lch[t]; break; }
}
else {
if(rch[t]) t=rch[t];
else {x=&rch[t]; break; }
}
}
*x=++size; val[size]=v; w[size]=1;
mant(root);
}
/* Remove one occurrence of v; does nothing if v is not stored.
   Subtree sizes along the search path are decremented on the way down, so
   presence is checked first. Decrementing a path that ends at the null
   sentinel would leave the ancestors' sizes wrong and drive w[0] negative.
   No rebalancing is done on deletion. */
void erase(int v) {
    int t = root;
    while (t && v != val[t])
        t = v < val[t] ? lch[t] : rch[t];
    if (!t)
        return;

    int *x = &root;
    for (;;) {
        w[*x]--;
        if (v < val[*x]) x = &lch[*x];
        else if (v > val[*x]) x = &rch[*x];
        else break;
    }
    /* At most one child: splice it into the parent's slot. */
    if (!lch[*x]) { *x = rch[*x]; return; }
    if (!rch[*x]) { *x = lch[*x]; return; }
    /* Two children: copy in the in-order successor (leftmost node of the
       right subtree) and unlink it; sizes on that path shrink by one. */
    int *y = &rch[*x];
    for (;;) {
        w[*y]--;
        if (!lch[*y]) break;
        y = &lch[*y];
    }
    val[*x] = val[*y];
    *y = rch[*y];
}
/* Reads n, then n pairs "op x":
   1 insert x, 2 erase x, 3 rank of x, 4 x-th smallest, 5 predecessor of x,
   6 successor of x. Queries print one integer per line.
   Malformed or truncated input stops processing with exit status 1. */
int main() {
    int n, op, arg;
    if (scanf("%d", &n) != 1)
        return 1;
    while (n-- > 0) {
        if (scanf("%d%d", &op, &arg) != 2)
            return 1;
        switch (op) {
        case 1: insert(arg); break;
        case 2: erase(arg); break;
        case 3: printf("%d\n", rank(arg)); break;
        case 4: printf("%d\n", select(arg)); break;
        case 5: printf("%d\n", pre(arg)); break;
        case 6: printf("%d\n", suc(arg)); break;
        default: break; /* unknown opcode: ignored */
        }
    }
    return 0;
}
简化代码后的SBT
没有assert了,不过明显短了很多
#include<stdio.h>
/* Node capacity; index 0 is the null sentinel, real nodes use 1..N-1. */
#define N 100005
int root;   /* index of the current root, 0 while the tree is empty */
int val[N]; /* value stored at each node */
int lch[N]; /* left child index, 0 = none */
int rch[N]; /* right child index, 0 = none */
int w[N];   /* subtree size; w[0] must stay 0 because 0 stands for "no child" */
int size;   /* number of values currently stored */
int cnt;    /* last allocated node index */
/* Right rotation of the subtree hanging in slot t: its left child becomes
   the new subtree root. The new root inherits the whole subtree size and
   the old root's size is recomputed from its (new) children. */
void rrot(int &t) {
    int oldRoot = t;
    int pivot = lch[oldRoot];
    lch[oldRoot] = rch[pivot];
    rch[pivot] = oldRoot;
    w[pivot] = w[oldRoot];
    w[oldRoot] = 1 + w[lch[oldRoot]] + w[rch[oldRoot]];
    t = pivot;
}
/* Left rotation of the subtree hanging in slot t: its right child becomes
   the new subtree root. The new root inherits the whole subtree size and
   the old root's size is recomputed from its (new) children. */
void lrot(int &t) {
    int oldRoot = t;
    int pivot = rch[oldRoot];
    rch[oldRoot] = lch[pivot];
    lch[pivot] = oldRoot;
    w[pivot] = w[oldRoot];
    w[oldRoot] = 1 + w[lch[oldRoot]] + w[rch[oldRoot]];
    t = pivot;
}
void mant(int &t) {
int fl=0,fr=0;
if(w[lch[lch[t]]]>w[rch[t]]) fr=1, rrot(t);
else if(w[rch[lch[t]]]>w[rch[t]]) fl=fr=1, lrot(lch[t]), rrot(t);
else if(w[rch[rch[t]]]>w[lch[t]]) fl=1, lrot(t);
else if(w[lch[rch[t]]]>w[lch[t]]) fl=fr=1, rrot(rch[t]), lrot(t);
if(fl) mant(lch[t]); if(fr) mant(rch[t]); if(fl||fr) mant(t);
}
void insert(int v) {
size++; int *x=&root;
while(w[*x]++) {
if(v<val[*x]) x=&lch[*x];
else x=&rch[*x];
}
*x=++cnt; val[cnt]=v; w[cnt]=1;
w[0]=0; mant(root);
}
/* Remove one occurrence of v. If v is not stored, nothing changes and the
   element count stays the same. Subtree sizes along the search path are
   decremented on the way down, so presence is checked first. Decrementing
   a path that ends at the null sentinel would leave the ancestors' sizes
   wrong and drive w[0] negative. No rebalancing is done on deletion. */
void erase(int v) {
    int t = root;
    while (t && v != val[t])
        t = v < val[t] ? lch[t] : rch[t];
    if (!t)
        return;

    size--;
    int *x = &root;
    for (;;) {
        w[*x]--;
        if (v < val[*x]) x = &lch[*x];
        else if (v > val[*x]) x = &rch[*x];
        else break;
    }
    /* At most one child: splice it into the parent's slot. */
    if (!lch[*x]) { *x = rch[*x]; return; }
    if (!rch[*x]) { *x = lch[*x]; return; }
    /* Two children: copy in the in-order successor (leftmost node of the
       right subtree) and unlink it; sizes on that path shrink by one. */
    int *y = &rch[*x];
    for (;;) {
        w[*y]--;
        if (!lch[*y]) break;
        y = &lch[*y];
    }
    val[*x] = val[*y];
    *y = rch[*y];
}
/* Return the k-th smallest stored value (1-based).
   The walk only continues through real nodes. If k is out of range it
   reaches the null sentinel and returns val[0]. That slot is never assigned
   (node indices start at 1), so the result is 0. The old loop spun forever
   at the sentinel when k <= 0. */
int select(int k) {
    int t = root;
    while (t) {
        int left = w[lch[t]];
        if (k <= left) {
            t = lch[t];
        } else if (k == left + 1) {
            return val[t];
        } else {
            k -= left + 1;
            t = rch[t];
        }
    }
    return val[0];
}
/* Rank of v: count of stored values strictly smaller than v, plus one.
   Works whether or not v itself is present. */
int rank(int v) {
    int smaller = 0;
    for (int t = root; t != 0; ) {
        if (val[t] < v) {
            smaller += w[lch[t]] + 1;  /* t and its left subtree are all < v */
            t = rch[t];
        } else {
            t = lch[t];
        }
    }
    return smaller + 1;
}
/* Predecessor: largest stored value strictly less than v.
   Returns val[0] (the never-assigned sentinel slot) when none exists. */
int pre(int v) {
    int best = 0, t = root;
    while (t != 0) {
        if (val[t] < v) {
            best = t;       /* candidate; a closer one may lie to the right */
            t = rch[t];
        } else {
            t = lch[t];
        }
    }
    return val[best];
}
/* Successor: smallest stored value strictly greater than v.
   Returns val[0] (the never-assigned sentinel slot) when none exists. */
int suc(int v) {
    int best = 0, t = root;
    while (t != 0) {
        if (val[t] > v) {
            best = t;       /* candidate; a closer one may lie to the left */
            t = lch[t];
        } else {
            t = rch[t];
        }
    }
    return val[best];
}
/* Reads n, then n pairs "op x":
   1 insert x, 2 erase x, 3 rank of x, 4 x-th smallest, 5 predecessor of x,
   6 successor of x. Queries print one integer per line.
   Malformed or truncated input stops processing with exit status 1. */
int main() {
    int n, op, arg;
    if (scanf("%d", &n) != 1)
        return 1;
    while (n-- > 0) {
        if (scanf("%d%d", &op, &arg) != 2)
            return 1;
        switch (op) {
        case 1: insert(arg); break;
        case 2: erase(arg); break;
        case 3: printf("%d\n", rank(arg)); break;
        case 4: printf("%d\n", select(arg)); break;
        case 5: printf("%d\n", pre(arg)); break;
        case 6: printf("%d\n", suc(arg)); break;
        default: break; /* unknown opcode: ignored */
        }
    }
    return 0;
}
然而 展开之后...
200行
#include<stdio.h>
/* Node capacity; index 0 is the null sentinel, real nodes use 1..N-1. */
#define N 100005
int root;   /* index of the current root, 0 while the tree is empty */
int val[N]; /* value stored at each node */
int lch[N]; /* left child index, 0 = none */
int rch[N]; /* right child index, 0 = none */
int w[N];   /* subtree size; w[0] must stay 0 because 0 stands for "no child" */
int size;   /* number of values currently stored */
int cnt;    /* last allocated node index */
/* Right rotation of the subtree hanging in slot t: its left child becomes
   the new subtree root. The new root inherits the whole subtree size and
   the old root's size is recomputed from its (new) children. */
void rrot(int &t)
{
    int oldRoot = t;
    int pivot = lch[oldRoot];
    lch[oldRoot] = rch[pivot];
    rch[pivot] = oldRoot;
    w[pivot] = w[oldRoot];
    w[oldRoot] = 1 + w[lch[oldRoot]] + w[rch[oldRoot]];
    t = pivot;
}
/* Left rotation of the subtree hanging in slot t: its right child becomes
   the new subtree root. The new root inherits the whole subtree size and
   the old root's size is recomputed from its (new) children. */
void lrot(int &t)
{
    int oldRoot = t;
    int pivot = rch[oldRoot];
    rch[oldRoot] = lch[pivot];
    lch[pivot] = oldRoot;
    w[pivot] = w[oldRoot];
    w[oldRoot] = 1 + w[lch[oldRoot]] + w[rch[oldRoot]];
    t = pivot;
}
/* SBT maintain: restore the size-balance property at slot t. Each child
   subtree of t must be at least as large as either child of its sibling.
   The first violated condition is fixed with a single or double rotation,
   then every subtree the rotation touched is re-checked, and finally t. */
void mant(int &t)
{
    int redoLeft = 0;
    int redoRight = 0;
    /* All four tests run before any rotation, so caching the children is safe. */
    int L = lch[t];
    int R = rch[t];
    if (w[lch[L]] > w[R])
    {
        /* left-left too heavy */
        rrot(t);
        redoRight = 1;
    }
    else if (w[rch[L]] > w[R])
    {
        /* left-right too heavy */
        lrot(lch[t]);
        rrot(t);
        redoLeft = 1;
        redoRight = 1;
    }
    else if (w[rch[R]] > w[L])
    {
        /* right-right too heavy */
        lrot(t);
        redoLeft = 1;
    }
    else if (w[lch[R]] > w[L])
    {
        /* right-left too heavy */
        rrot(rch[t]);
        lrot(t);
        redoLeft = 1;
        redoRight = 1;
    }
    else
    {
        /* already balanced */
        return;
    }
    if (redoLeft)
    {
        mant(lch[t]);
    }
    if (redoRight)
    {
        mant(rch[t]);
    }
    mant(t);
}
void insert(int v)
{
size++;
int *x=&root;
while(w[*x]++)
{
if(v<val[*x])
{
x=&lch[*x];
}
else
{
x=&rch[*x];
}
}
*x=++cnt;
val[cnt]=v;
w[cnt]=1;
w[0]=0;
mant(root);
}
/* Remove one occurrence of v. If v is not stored, nothing changes and the
   element count stays the same. Subtree sizes along the search path are
   decremented on the way down, so presence is checked first. Decrementing
   a path that ends at the null sentinel would leave the ancestors' sizes
   wrong and drive w[0] negative. No rebalancing is done on deletion. */
void erase(int v)
{
    int t = root;
    while (t && v != val[t])
    {
        if (v < val[t])
        {
            t = lch[t];
        }
        else
        {
            t = rch[t];
        }
    }
    if (!t)
    {
        return;
    }
    size--;
    int *x = &root;
    for (;;)
    {
        w[*x]--;
        if (v < val[*x])
        {
            x = &lch[*x];
        }
        else if (v > val[*x])
        {
            x = &rch[*x];
        }
        else
        {
            break;
        }
    }
    /* At most one child: splice it into the parent's slot. */
    if (!lch[*x])
    {
        *x = rch[*x];
        return;
    }
    if (!rch[*x])
    {
        *x = lch[*x];
        return;
    }
    /* Two children: copy in the in-order successor (leftmost node of the
       right subtree) and unlink it; sizes on that path shrink by one. */
    int *y = &rch[*x];
    for (;;)
    {
        w[*y]--;
        if (!lch[*y])
        {
            break;
        }
        y = &lch[*y];
    }
    val[*x] = val[*y];
    *y = rch[*y];
}
/* Return the k-th smallest stored value (1-based).
   The walk only continues through real nodes. If k is out of range it
   reaches the null sentinel and returns val[0]. That slot is never assigned
   (node indices start at 1), so the result is 0. The old loop spun forever
   at the sentinel when k <= 0. */
int select(int k)
{
    int t = root;
    while (t)
    {
        int left = w[lch[t]];
        if (k <= left)
        {
            t = lch[t];
        }
        else if (k == left + 1)
        {
            return val[t];
        }
        else
        {
            k -= left + 1;
            t = rch[t];
        }
    }
    return val[0];
}
/* Rank of v: count of stored values strictly smaller than v, plus one.
   Works whether or not v itself is present. */
int rank(int v)
{
    int smaller = 0;
    for (int t = root; t != 0; )
    {
        if (val[t] < v)
        {
            /* t and its whole left subtree are < v */
            smaller += w[lch[t]] + 1;
            t = rch[t];
        }
        else
        {
            t = lch[t];
        }
    }
    return smaller + 1;
}
/* Predecessor: largest stored value strictly less than v.
   Returns val[0] (the never-assigned sentinel slot) when none exists. */
int pre(int v)
{
    int best = 0;
    for (int t = root; t != 0; )
    {
        if (val[t] < v)
        {
            /* candidate; a closer one may lie to the right */
            best = t;
            t = rch[t];
        }
        else
        {
            t = lch[t];
        }
    }
    return val[best];
}
/* Successor: smallest stored value strictly greater than v.
   Returns val[0] (the never-assigned sentinel slot) when none exists. */
int suc(int v)
{
    int best = 0;
    for (int t = root; t != 0; )
    {
        if (val[t] > v)
        {
            /* candidate; a closer one may lie to the left */
            best = t;
            t = lch[t];
        }
        else
        {
            t = rch[t];
        }
    }
    return val[best];
}
/* Reads n, then n pairs "op x":
   1 insert x, 2 erase x, 3 rank of x, 4 x-th smallest, 5 predecessor of x,
   6 successor of x. Queries print one integer per line.
   Malformed or truncated input stops processing with exit status 1. */
int main()
{
    int n, op, arg;
    if (scanf("%d", &n) != 1)
    {
        return 1;
    }
    while (n-- > 0)
    {
        if (scanf("%d%d", &op, &arg) != 2)
        {
            return 1;
        }
        switch (op)
        {
        case 1:
            insert(arg);
            break;
        case 2:
            erase(arg);
            break;
        case 3:
            printf("%d\n", rank(arg));
            break;
        case 4:
            printf("%d\n", select(arg));
            break;
        case 5:
            printf("%d\n", pre(arg));
            break;
        case 6:
            printf("%d\n", suc(arg));
            break;
        default:
            /* unknown opcode: ignored */
            break;
        }
    }
    return 0;
}