之所以有这篇博客是因为网上的 AVL 树代码普遍很长,讲解也很复杂,实际上这是一个很显然且容易编写的数据结构。本篇博客使用数组实现 AVL 树,指针同理,可以用作 AVL 树入门。
思想
平衡树是二叉搜索树(BST)的改进,二叉搜索树每个点维护 v a l u val_u valu 表示结点 u u u 的权值, l s x ls_x lsx 和 r s x rs_x rsx 分别表示结点 u u u 的左儿子和右儿子,满足 v a l l s x < v a l x < v a l r s x val_{ls_x}<val_x<val_{rs_x} vallsx<valx<valrsx。即左儿子权值小于父亲权值小于右儿子权值。
BST 在极端数据下会出现失衡的状况,于是有了平衡树。
AVL 树是平衡树的一种,令重量 s z x sz_x szx 表示以 x x x 为根的子树大小,则 ∣ s z l s x − s z r s x ∣ ≤ 1 |sz_{ls_x}-sz_{rs_x}|\le1 ∣szlsx−szrsx∣≤1,即左右儿子重量之差的绝对值不超过 1 1 1。
AVL 树通过旋转来维护树的平衡。
旋转
AVL 树的旋转只有一种,就是单旋,对于结点 x x x 和其父亲 f a fa fa,将结点 x x x 向上旋转就是让 x x x 作为 f a fa fa 的父亲,原来 x x x 连接的一个儿子 v v v 作为 f a fa fa 的儿子。
初学建议画图。
另外,可能有重复的值,用 c n t x cnt_x cntx 表示 v a l x val_x valx 的个数, s z x sz_x szx 表示子树 x 值的个数, n s z x nsz_x nszx 表示子树 x x x 结点个数, c h x , 0 / 1 ch_{x,0/1} chx,0/1 表示结点 x 的左 / 右儿子。
// 维护 sz 和 nsz
void pushup(int x) { sz[x] = sz[ls] + sz[rs] + cnt[x], nsz[x] = nsz[ls] + nsz[rs] + 1; }
void spin(int x, int & fa) { // 旋转操作
int d = ch[fa][1] == x; // d 表示 x 是 fa 的左/右儿子
ch[fa][d] = ch[x][d ^ 1], ch[x][d ^ 1] = fa;
pushup(fa), pushup(x), fa = x;
}
维持平衡
当不满足 ∣ n s z l s x − n s z r s x ∣ ≤ 1 |nsz_{ls_x}-nsz_{rs_x}|\le1 ∣nszlsx−nszrsx∣≤1 时,则将其进行旋转。
正常来说,应该先将其转成都往一边偏的形状,再将结点 x x x 重的儿子进行旋转:
x x d
/ \ / \ / \
a b d b a x
/ \ --> / \ --> / \ / \
c d a f c e f b
/ \ / \
e f c e
代码理应是:
void maintain(int & x) { // 维持平衡
int d = abs(nsz[ls] - nsz[rs]) > 1 ? nsz[ls] < nsz[rs] : -1;
if(d == -1) return;
int & c = ch[x][d];
if(nsz[ch[c][d]] < nsz[ch[c][d ^ 1]]) spin(ch[c][d ^ 1], c);
spin(c, x);
}
但实际测试和直接将结点 x x x 的重的儿子进行旋转效率差不多,甚至后者还更快(这里当初写错了几次,结果改的复杂度越正确跑得越慢):
void maintain(int & x) { // 维持平衡
if(nsz[ls] > nsz[rs] + 1) spin(ls, x);
if(nsz[rs] > nsz[ls] + 1) spin(rs, x);
}
因此采用下面这种。
插入
现在要插入值 v v v,对于结点 x x x,若 v < v a l x v<val_x v<valx,则向左递归;若 v > v a l x v>val_x v>valx,则向右递归,若 v = v a l x v=val_x v=valx,则令 c n t x ← c n t x + 1 cnt_x\leftarrow cnt_x+1 cntx←cntx+1 并返回。当递归到叶子结点时将其作为叶子节点的儿子。
// 添加结点,cl 用来回收删除的结点(不持久化的话其实没必要)
void add_node(int & x, int v) { val[x = cln ? cl[cln--] : ++tot] = v, ls = rs = 0, sz[x] = nsz[x] = 1, cnt[x] = 1; }
void ins(int & x, int v) { // 插入
if(!x) return add_node(x, v), void();
if(v == val[x]) return ++cnt[x], ++sz[x], void();
v < val[x] ? ins(ls, v) : ins(rs, v);
maintain(x), pushup(x);
}
删除
思路是找到了要删除的结点 x x x 后,若 c n t x > 1 cnt_x>1 cntx>1,则将其减一,否则将 x x x 旋转成叶子结点并删除。
若旋转到某一时刻使得 x x x 没有左儿子或者没有右儿子,可以直接删除结点 x x x 并返回。
尽量旋转 n s z nsz nsz 较大的儿子,这样能保证树不会失衡。
void del(int & x, int v) { // 删除
if(!x) return;
if(v == val[x]) { // 找到了结点 x
if(cnt[x] > 1) return --cnt[x], --sz[x], void(); // cnt[x]>1 就直接将其减 1
if(!ls || !rs) return cl[++cln] = x, x = ls | rs, void(); // 没有左儿子或右儿子就直接删除结点 x
spin(sz[ls] > sz[rs] ? ls : rs, x), del(x, v); // 否则将其旋转
} else del(v < val[x] ? ls : rs, v);
pushup(x);
}
获取排名
现在要获取值 v v v 的排名(比 v v v 小的数的个数 + 1 +1 +1),可以在树上二分,具体见代码。
int grk(int x, int v) { // 获取比 v 小的数的个数
if(!x) return 0;
if(v <= val[x]) return grk(ls, v); // 若 v<=val[x],则递归左子树
return grk(rs, v) + sz[ls] + cnt[x]; // 否则,答案为递归右子树得到的答案,加上左子树大小和 cnt[x]
}
第 k 大 / 小
以第 k 小为例,同样在树上二分。
int kth(int x, int rk) { // 获取第 rk 小
if(ls && rk <= sz[ls]) return kth(ls, rk);
if(rk <= sz[ls] + cnt[x]) return val[x];
return kth(rs, rk - sz[ls] - cnt[x]);
}
求前驱 / 后继
一个数的前驱指的是比这个数小且最大的数,后继指的是比这个数大且最小的数。
怎么求呢?有个简单方法,比如现在 AVL 树维护的是 { 1 , 2 , 2 , 2 , 3 , 3 , 3 , 4 , 5 } \{1,2,2,2,3,3,3,4,5\} {1,2,2,2,3,3,3,4,5},如果要求 3 3 3 的前驱,那么先求出比 3 3 3 小的数的个数,即 { 1 , 2 , 2 , 2 } \{1,2,2,2\} {1,2,2,2} 4 个,然后求第 4 4 4 小,就求出了 2 2 2,即 3 3 3 的前驱。求后继同理。
这样写经实际测试常数并不是很大,可以接受。
int pre(int v) { return kth(rt, grk(rt, v)); } // 求前驱
int nxt(int v) { return kth(rt, grk(rt, v + 1) + 1); } // 求后继
完整代码
const int N = 1100010;
namespace AVL {
int val[N], ch[N][2], sz[N], nsz[N], tot, cnt[N], rt;
int cl[N], cln;
void pushup(int x) { sz[x] = sz[ls] + sz[rs] + cnt[x], nsz[x] = nsz[ls] + nsz[rs] + 1; }
void spin(int x, int & fa) {
int d = ch[fa][1] == x;
ch[fa][d] = ch[x][d ^ 1], ch[x][d ^ 1] = fa;
pushup(fa), pushup(x), fa = x;
}
void maintain(int & x) {
if(nsz[ls] > nsz[rs] + 1) spin(ls, x);
if(nsz[rs] > nsz[ls] + 1) spin(rs, x);
}
void add_node(int & x, int v) { val[x = cln ? cl[cln--] : ++tot] = v, ls = rs = 0, sz[x] = nsz[x] = 1, cnt[x] = 1; }
void ins(int & x, int v) {
if(!x) return add_node(x, v), void();
if(v == val[x]) return ++cnt[x], ++sz[x], void();
v < val[x] ? ins(ls, v) : ins(rs, v);
maintain(x), pushup(x);
}
void del(int & x, int v) {
if(!x) return;
if(v == val[x]) {
if(cnt[x] > 1) return --cnt[x], --sz[x], void();
if(!ls || !rs) return cl[++cln] = x, x = ls | rs, void();
spin(sz[ls] > sz[rs] ? ls : rs, x), del(x, v);
} else del(v < val[x] ? ls : rs, v);
pushup(x);
}
int grk(int x, int v) {
if(!x) return 0;
if(v <= val[x]) return grk(ls, v);
return grk(rs, v) + sz[ls] + cnt[x];
}
int kth(int x, int rk) {
if(ls && rk <= sz[ls]) return kth(ls, rk);
if(rk <= sz[ls] + cnt[x]) return val[x];
return kth(rs, rk - sz[ls] - cnt[x]);
}
int pre(int v) { return kth(rt, grk(rt, v)); }
int nxt(int v) { return kth(rt, grk(rt, v + 1) + 1); }
}
共 54 行且没有刻意压行,应该算是比较简短的,相比 Splay,替罪羊树,Treap,FHQ Treap,WBLT 之类的好写多了,效率很可观。