·4.7 伸展树(splay)
- 结构体: SplayTree
- 成员函数:
void split(int &x,int &y,int a);
//（代码中该成员函数的名字拼写为 spilt）
//x 为分裂前的树根,a 为分裂之后前半部分的大小
//输出 x,y 分别为两棵树的树根
时间复杂度:
均摊 O(logN)
void join(int &x,int y);
//x,y 为合并之前的两棵树的树根,y 接在 x 之后
//x 为合并之后的树根
时间复杂度:
均摊 O(logN)
int getRank(int x);
//输入 x 为 splay 中的节点(需位于以 root 为根的树中)
//返回 x 之前的节点个数(即从 0 开始计的排名),并将 x 伸展为新的 root
时间复杂度:
均摊 O(logN)
void split3(int &x,int &y,int &z,int a,int b);
//（代码中该成员函数的名字拼写为 spilt3）
//输入 x 为分裂前的树根
//输出 x,y,z 三棵树的树根:x 为前 a-1 个节点,y 为第 a 到第 b 个节点,z 为第 b 个之后的节点
时间复杂度:
均摊 O(logN)
void join3(int &x,int y,int z);
//按 x,y,z 的顺序合并三棵树,x 为合并后的树根
时间复杂度:
均摊 O(logN)
void reverse(int a,int b);
//将区间[a,b]翻转(作用于以 root 为根的树)
//注意:代码实际按第 a+1 到第 b+1 个节点分裂,推测序列首部放置了一个哨兵节点(请确认)
时间复杂度:
均摊 O(logN)
const int MAXN = 10000; // node-pool capacity; index 0 is the null node

/*
 * Splay tree over a sequence: nodes are ordered by in-order position (not by
 * value). Supports split / join and lazy range reversal.
 *
 * Array layout:
 *   node 0          null node; size[0] must stay 0, it is never a real node
 *   type[x]         0 / 1 if x is the left / right child of parent[x],
 *                   2 if x is the root of its tree
 *   childs[x][0/1]  left / right child (0 = none)
 *   size[x]         number of nodes in the subtree of x
 *   reversed[x]     1 if the subtree of x still has to be mirrored; the flag
 *                   is pushed down to the children by pass(x)
 *   val[x]          payload stored in x
 *   stack           scratch buffer used by splay()
 * Positions ("rank") passed to find/spilt are 1-based in-order indices.
 */
struct SplayTree
{
    int nodeCnt, root, type[MAXN], parent[MAXN], childs[MAXN][2],
        size[MAXN], stack[MAXN], reversed[MAXN];
    int val[MAXN];

    // Reset to an empty node pool; the tree `root` becomes empty.
    void clear()
    {
        root = size[0] = 0;
        nodeCnt = 1;
    }

    // Allocate a detached single-node tree holding v and return its index.
    // No capacity check: callers must keep the number of allocations < MAXN.
    int malloc(int v)
    {
        type[nodeCnt] = 2;
        parent[nodeCnt] = 0;
        val[nodeCnt] = v;
        childs[nodeCnt][0] = childs[nodeCnt][1] = 0;
        size[nodeCnt] = 1;
        reversed[nodeCnt] = 0;
        return nodeCnt++;
    }

    // Recompute size[x] from its children.
    void update(int x)
    {
        size[x] = size[childs[x][0]] + 1 + size[childs[x][1]];
    }

    // Push a pending reversal of x down one level: swap x's children, fix
    // their left/right type, and hand the flag on to both of them.
    void pass(int x)
    {
        if (!reversed[x])
            return;
        swap(childs[x][0], childs[x][1]);
        int l = childs[x][0], r = childs[x][1];
        if (l)
        {
            type[l] = 0;
            reversed[l] ^= 1;
        }
        if (r)
        {
            type[r] = 1;
            reversed[r] ^= 1; // the right subtree must inherit the reversal too
        }
        reversed[x] = 0;
    }

    // Rotate x above its parent y. Only y's size is refreshed here; x's size
    // is refreshed by splay() once x has stopped moving.
    // Requires: x is not a root, and pass() was already applied to x and y.
    void rotate(int x)
    {
        int t = type[x], y = parent[x], z = childs[x][1 - t];
        type[x] = type[y];
        parent[x] = parent[y];
        if (type[x] != 2)
            childs[parent[x]][type[x]] = x;
        type[y] = 1 - t;
        parent[y] = x;
        childs[x][1 - t] = y;
        if (z)
        {
            type[z] = t;
            parent[z] = y;
        }
        childs[y][t] = z;
        update(y);
    }

    // Move x to the root of its tree. Pending reversals on the root-to-x path
    // are pushed down top-down first, so child links and types are correct
    // before any rotation.
    void splay(int x)
    {
        int stackCnt = 0;
        stack[stackCnt++] = x;
        for (int i = x; type[i] != 2; i = parent[i])
            stack[stackCnt++] = parent[i];
        for (int i = stackCnt - 1; i >= 0; i--)
            pass(stack[i]);
        while (type[x] != 2)
        {
            int y = parent[x];
            if (type[x] == type[y]) // zig-zig: rotate the parent first
                rotate(y);
            else                    // zig-zag, or a single zig when y is the root
                rotate(x);
            if (type[x] == 2)
                break;
            rotate(x);
        }
        update(x);
    }

    // Return the node at 1-based in-order position `rank` in the tree rooted
    // at x, pushing down pending reversals on the way. Does not splay.
    // Requires 1 <= rank <= size[x].
    int find(int x, int rank)
    {
        while (true)
        {
            pass(x);
            if (size[childs[x][0]] + 1 == rank)
                break;
            if (rank <= size[childs[x][0]])
                x = childs[x][0];
            else
            {
                rank -= size[childs[x][0]] + 1;
                x = childs[x][1];
            }
        }
        return x;
    }

    // Split the tree rooted at x: afterwards x holds the first a nodes and y
    // holds the rest (either may be 0 = empty). Requires a >= 0.
    void spilt(int &x, int &y, int a)
    {
        if (a >= size[x])
        {
            // Nothing to cut off. find(x, a + 1) would run off the tree and
            // splay the null node, corrupting size[0].
            y = 0;
            return;
        }
        y = find(x, a + 1);
        splay(y);
        x = childs[y][0];
        if (x)
            type[x] = 2; // the detached left subtree becomes a root
        childs[y][0] = 0;
        update(y);
    }

    // Correctly spelled alias of spilt().
    void split(int &x, int &y, int a)
    {
        spilt(x, y, a);
    }

    // Split x into three trees: x = nodes 1..a-1, y = nodes a..b,
    // z = nodes after b. Requires a >= 1.
    void spilt3(int &x, int &y, int &z, int a, int b)
    {
        spilt(x, z, b);
        spilt(x, y, a - 1);
    }

    // Correctly spelled alias of spilt3().
    void split3(int &x, int &y, int &z, int a, int b)
    {
        spilt3(x, y, z, a, b);
    }

    // Append tree y after tree x; x becomes the root of the result.
    // Either tree may be empty (0).
    void join(int &x, int y)
    {
        if (!x)
        {
            // find(0, 0) would never terminate; the result is simply y.
            x = y;
            return;
        }
        x = find(x, size[x]); // last node: it has no right child
        splay(x);
        childs[x][1] = y;
        if (y)
        {
            type[y] = 1;
            parent[y] = x;
        }
        update(x);
    }

    // Concatenate x, y, z in that order; x becomes the root of the result.
    void join3(int &x, int y, int z)
    {
        join(y, z);
        join(x, y);
    }

    // Splay x (which must lie in the tree `root`) to become the new root and
    // return the number of nodes before it, i.e. its 0-based position.
    int getRank(int x)
    {
        splay(x);
        root = x;
        return size[childs[x][0]];
    }

    // Reverse in-order positions a+1 .. b+1 of the tree `root`. The +1 offset
    // presumably skips a sentinel stored at position 1 (confirm with callers).
    void reverse(int a, int b)
    {
        int x, y;
        spilt3(root, x, y, a + 1, b + 1);
        if (x)
            reversed[x] ^= 1;
        join3(root, x, y);
    }
};