【模板】Splay
S p l a y Splay Splay,名为伸展树,通过不断将某个节点旋转到根节点,使得整棵树保持 B S T BST BST性质的同时,尽量平横。
维护的信息
r t rt rt | t o t tot tot | f a [ i ] fa[i] fa[i] | c h [ i ] [ 0 / 1 ] ch[i][0/1] ch[i][0/1] | v a l [ i ] val[i] val[i] | c n t [ i ] cnt[i] cnt[i] | s i z [ i ] siz[i] siz[i] |
---|---|---|---|---|---|---|
根结点编号 | 节点总数 | 父亲节点 | 左右儿子节点 | 节点权值 | 权值出现次数 | 子树大小 |
类定义
class Splay {
private:
int rt, tot, fa[maxn], ch[maxn][2], cnt[maxn], siz[maxn];
public:
int val[maxn];
};
操作
基本操作
m a i n t a i n ( x ) maintain(x) maintain(x) : 更新节点 x x x 的 s i z siz siz 值
g e t ( x ) get(x) get(x) : 判断节点 x x x 是其父亲的左儿子还是右儿子
c l e a r ( x ) clear(x) clear(x) : 清除节点 x x x 的数据
void maintain(int x) {
siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x]; }
bool get(int x) {
return x == ch[fa[x]][1]; }
void clear(int x) {
ch[x][0] = ch[x][1] = fa[x] = cnt[x] = siz[x] = val[x] = 0; }
旋转操作
r o t a t e ( x ) rotate(x) rotate(x) : 将 x x x 向上旋转
例如上面这张图,将左图的2右旋,就可以得到右图
旋转的一般操作是(设需要旋转的节点为 x x x ,其父亲节点为 y y y ,下面以右旋为例):
- 将 y y y 的左儿子指向 x x x 的右儿子, x x x 的右儿子(如果有)的父亲指向 y y y
- 将 x x x 的右儿子指向 y y y ,且 y y y 的父亲指向 x x x
- 如果 y y y 存在父亲 z z z ,则另原来 y y y 位置指向 x x x , x x x 的父亲指向 z z z
- 最后维护节点的 s i z siz siz ,注意顺序,先 x x x 后 y y y
void rotate(int x) {
int y = fa[x], z = fa[y], chk = get(x);
ch[y][chk] = ch[x][chk ^ 1];
if(ch[x][chk ^ 1]) fa[ch[x][chk ^ 1]] = y;
ch[x][chk ^ 1] = y;
fa[y] = x; fa[x] = z;
if(z) ch[z][y == ch[z][1]] = x;
maintain(x); maintain(y);
}
伸展操作
s p l a y ( x ) splay(x) splay(x) : 将 x x x 旋转到根
根据 s p l a y splay splay 的性质,每次访问完当前节点后,要强制将其旋转到根节点,情况共有6种:
- x x x 的父亲为根节点,直接将 x x x 左旋或右旋
- x x x 的父亲节点不是根节点,但 x x x 和父亲的儿子类型相同,先将其父亲左旋或右旋,然后将 x x x 右旋或左旋
- x x x 的父亲节点不是根节点,且 x x x 和父亲的儿子类型不同,将 x x x 左旋再右旋,或者右旋再左旋
void splay(int x) {
for (int f = fa[x]; f = fa[x], f; rotate(x))
if(fa[f]) rotate(get(x) == get(f) ? f : x);
rt = x;
}
插入操作
i n s e r t ( x ) insert(x) insert(x) : 将 x x x 插入到 s p l a y splay splay 中
插入分以下几种情况:
- 树为空,直接插入根节点
- 当前节点的 v a l val val 等于 k k k ,增加当前节点的大小,并更新其和其父亲节点的信息,对当前节点 s p l a y splay splay
- 按照 B S T BST BST 性质向下找,找到空节点插入即可
void insert(int k) {
if(!rt) {
val[++tot] = k;
cnt[tot]++;
rt = tot;
maintain(rt);
return;
}
int cur = rt, f =