在各种各样的数据结构中,有很多树,可以查找第k大的数,比如划分树,查找区间最大值,比如线段树,但是绝大部分的树形数据结构都不能进行区间删除,但是有一种数据结构能进行区间删除,还能进行树的合并,它既是伸展树。
伸展树跟平衡树等类似,也是一颗有序的二叉树,也有左旋右旋操作,查找插入删除等操作的平均复杂度也是log(n)级别的,他的特点就是在进行在对某个数进行操作时会将这个数旋转到树根的地方,这一特点就能让伸展树很灵活,例如对于两颗树的合并,先将这两颗树的最大值旋转到树根的地方,然后这两颗树就变成了根只有右儿子的树了,这样进行合并就简单的多了。
struct node {
node *left, *right;
node *par; //父节点
int val, w; // 节点所存的值和个数
};
然后可以试着写一下左旋还有右旋的函数,方便以后直接用。
左旋跟右旋一样,思想就是在不更改元素大小位置的前提下将某个元素往上移动。
可以画个图试一下,图以后再补。
左旋:
void LeftRotate(node *x) { // 将x节点旋转到他的父节点的位置
node *y = x->par;
if (x->left) x->left->par = y;
y->right = x->left;
x->left = y;
x->par = y->par;
if (y->par) {
if (y->par->right == y) y->par->right = x;
else y->par->left = x;
}
y->par = x;
}
右旋:
void RightRotate(node *x) {
node *y = x->par;
if (x->right) x->right->par = y;
y->left = x->right;
x->right = y;
x->par = y->par;
if (y->par) {
if (y->par->right == y) y->par->right = x;
else y->par->left = x;
}
y->par = x;
}
然后就是伸展操作了,这个操作就是将某个节点移动到另一个节点的儿子节点上。具体看代码讲解。
void splay(node *x, node *y) {
while (x->par != y) {
if (x->par->par == y) x->par->right == x ? LeftRotate(x) : RightRotate(x); //这里分成了两种情况,一种是只需要一步旋转即可完成x旋转到y下面,第二种就是四种情况,即左左,左右,右右,右左,以后再补图
else {
if (x->par->par->right == x->par) {
if (x->par->right == x) { // 右右
LeftRotate(x);
LeftRotate(x);
} else { //右左
RightRotate(x);
LeftRotate(x);
}
} else { // 左右
if (x->par->right == x) {
LeftRotate(x);
RightRotate(x);
} else { // 左左
RightRotate(x);
RightRotate(x);
}
}
}
}
if (y == 0) root = x; //别忘了根节点
}
以上就是伸展所需要的最基本的操作了,可以利用这些操作干很多事。
比如插入操作。
void Insert(int val) {
node *p = root, *tmp = NULL;
while (true) {
if (p == NULL) {
p = new node;
p->left = p->right = NULL;
p->val = val;
p->w = 1;
p->par = tmp;
if (tmp) {
if (tmp->val > val) tmp->left = p;
else tmp->right = p;
}
splay(p, 0);
break;
} else {
tmp = p;
if (p->val == val) {
p->w++;
splay(p, 0);
break;
}
else if (p->val > val) p = p->left;
else p = p->right;
}
}
}
还有删除操作。
void join(node *x, node *y) { // 这个函数是将以x为根的树和以y为根的树合并,很常用
if (!x) {
root = y;
return ;
}
if (!y) {
root = x;
return ;
}
node *a = y;
while (a->right) a = a->right;
splay(a, 0);
x->par = a;
a->left = x;
root = a;
}
void Remove(node *rt, int val) {
// 删除操作实际上是将要删除的元素旋转到根,然后进行删除
if (val == rt->val) {
if (rt->w > 1) rt->w--;
else {
splay(rt, 0);
node *a = rt->left, *b = rt->right;
rt->left = rt->right = 0;
if (a) a->par = 0;
if (b) b->par = 0;
join(a, b);
delete rt;
}
} else if (rt->val > val) Remove(rt->left, val);
else Remove(rt->right, val);
}
代码挺好理解的。以后还会继续补充。
总代码:
#include <iostream>
#include <algorithm>
using namespace std;
struct node {
node *left, *right;
node *par;
int val, w;
};
struct node *root;
void LeftRotate(node *x) {
node *y = x->par;
if (x->left) x->left->par = y;
y->right = x->left;
x->left = y;
x->par = y->par;
if (y->par) {
if (y->par->right == y) y->par->right = x;
else y->par->left = x;
}
y->par = x;
}
void RightRotate(node *x) {
node *y = x->par;
if (x->right) x->right->par = y;
y->left = x->right;
x->right = y;
x->par = y->par;
if (y->par) {
if (y->par->right == y) y->par->right = x;
else y->par->left = x;
}
y->par = x;
}
void splay(node *x, node *y) {
while (x->par != y) {
if (x->par->par == y) x->par->right == x ? LeftRotate(x) : RightRotate(x);
else {
if (x->par->par->right == x->par) {
if (x->par->right == x) {
LeftRotate(x);
LeftRotate(x);
} else {
RightRotate(x);
LeftRotate(x);
}
} else {
if (x->par->right == x) {
LeftRotate(x);
RightRotate(x);
} else {
RightRotate(x);
RightRotate(x);
}
}
}
}
if (y == 0) root = x;
}
void Insert(int val) {
node *p = root, *tmp = NULL;
while (true) {
if (p == NULL) {
p = new node;
p->left = p->right = NULL;
p->val = val;
p->w = 1;
p->par = tmp;
if (tmp) {
if (tmp->val > val) tmp->left = p;
else tmp->right = p;
}
splay(p, 0);
break;
} else {
tmp = p;
if (p->val == val) {
p->w++;
splay(p, 0);
break;
}
else if (p->val > val) p = p->left;
else p = p->right;
}
}
}
void print(node *rt) {
if (rt->left != NULL) print(rt->left);
printf("%d ", rt->val);
if (rt->right != NULL) print(rt->right);
}
void join(node *x, node *y) {
if (!x) {
root = y;
return ;
}
if (!y) {
root = x;
return ;
}
node *a = y;
while (a->right) a = a->right;
splay(a, 0);
x->par = a;
a->left = x;
root = a;
}
void Remove(node *rt, int val) {
if (val == rt->val) {
if (rt->w > 1) rt->w--;
else {
splay(rt, 0);
node *a = rt->left, *b = rt->right;
rt->left = rt->right = 0;
if (a) a->par = 0;
if (b) b->par = 0;
join(a, b);
delete rt;
}
} else if (rt->val > val) Remove(rt->left, val);
else Remove(rt->right, val);
}
int n, m;
int main() {
int x;
scanf("%d", &n);
for (int i = 0; i < n; ++i) {
scanf("%d", &x);
Insert(x);
}
scanf("%d", &m);
for (int i = 0; i < m; ++i) {
scanf("%d", &x);
Remove(root, x);
}
print(root);
printf("\n");
return 0;
}