非旋 Treap 的每个操作的时间复杂度是 O(logn)
的
Luogu P3369 【模板】普通平衡树(Treap/SBT)
- 插入 x 数
- 删除 x 数(若有多个相同的数,因只删除一个)
- 查询 x 数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,因输出最小的排名)
- 查询排名为 k 的数
- 求 x 的前驱(前驱定义为小于 x,且最大的数)
- 求 x 的后继(后继定义为大于 x,且最小的数)
Merge
Split
Insert
新插入一个权值为 x 的点,找 x 的排名时,如果 Treap 里已有权值也为 x 的点,就返回 < x 的数的个数 + 1。
否则返回 < x 的数的个数。
erase
(更正 : k 应该是 < x 的个数
指针版:
#include <bits/stdc++.h>
using namespace std;
struct Treap {
struct Node {
Node *lc, *rc;
int x, size, key; // key 为优先级
Node(int x) : lc(NULL), rc(NULL), x(x), size(1), key(rand()) {}
inline void maintain() { size = (lc ? lc->size : 0) + (rc ? rc->size : 0) + 1; }
inline int lSize() { return lc ? lc->size : 0; }
} *root;
// 合并
// 需要保证 a 中的所有节点比 b 中的小
Node *merge(Node *a, Node *b) {
if (!a && !b) return NULL;
if (!a) { b->maintain(); return b; } //
if (!b) { a->maintain(): return a; }
if (a->key > b->key) {
a->rc = merge(a->rc, b); // 一定要注意是 merge(a->rc, b) 而不是 merge(b, a->rc),显然仍然要保证前一个 Treap 小于后一个 Treap。
a->maintain();
return a;
} else {
b->lc = merge(a, b->lc);
b->maintain();
return b;
}
}
// 将 v 这棵子树分离成两棵,前 k 个节点作为 l 返回,之后的节点作为 r 返回
inline void split(Node *v, int k, Node *&l, Node *&r) {
if (!v) { l = r = NULL; return ;}
int s = v->lSize();
if(k <= s) {
split(v->lc, k, l, r);
v->lc = r; // 将返回的不在前 k 个范围内的 r 接到 v 的左儿子,为什么不是右儿子呢,因为右儿子有东西了啊
r = v; // 将剩余的 v 作为 r 返回,属于前 k 个数的已经在上两步被 split 掉了,所以剩下的 v 就是不在前 k 个数中的了
} else {
split(v->rc, k - s - 1, l, r); // 注意节点数要改变
v->rc = l;
l = v;
}
// 注意重新维护信息
v->maintain();
}
// 返回有多少个 < x 的数
inline int lowerCount(int x) {
Node *v = root;
int res = 0;
while(v) {
// 如果 v->x == x,走前一个分支,不将当前点加到答案中
if (x <= v->x) v = v->lc;
else {
res += v->lSize() + 1;
v = v->rc;
}
}
return res;
}
// 返回有多少个 <= x 的数
inline int upperCount(int x) {
Node *v = root;
int res = 0;
while(v) {
// 如果 v->x == x,走后一个分支,将当前点加到答案中
if (x < v->x) v = v->lc;
else {
res += v->lSize() + 1;
v = v->rc;
}
}
return res;
}
inline void insert(int x) {
int cnt = lowerCount(x);
Node *l, *r;
split(root, cnt, l, r);
// 新节点
Node *v = new Node(x);
root = merge(merge(l, v), r);
}
inline void erase(int x) {
int cnt = lowerCount(x);
// pred 为 x 的前趋及之前的所有节点
// tmp 为最靠前的一个等于 x 的节点(及其子树)
Node *pred, *tmp;
split(root, cnt, pred, tmp);
// succ 为不需要删除的多余的 x 和 x 的后继及之后的所有节点
// target 表示实际删除的唯一一个节点
Node *target, *succ;
split(tmp, 1, target, succ); // 注意这里要分离的是 tmp
root = merge(pred, succ);
}
// 查询排名为 k 的数
inline int select(int k) {
Node *v = root;
while(v->lSize() != k - 1) {
if(v->lSize() > k - 1) v = v->lc;
else {
k -= v->lSize() + 1;
v = v->rc;
}
}
return v->x;
}
inline int rank(int x) {
return lowerCount(x) + 1;
}
inline int pred(int x) {
return select(lowerCount(x));
}
inline int succ(int x) {
return select(upperCount(x) + 1);
}
}treap;
int main() {
int n;
scanf("%d", &n);
while (n --) {
int opt, x;
scanf("%d%d", &opt, &x);
if (opt == 1) treap.insert(x);
else if (opt == 2) treap.erase(x);
else if (opt == 3) printf("%d\n", treap.rank(x));
else if (opt == 4) printf("%d\n", treap.select(x));
else if (opt == 5) printf("%d\n", treap.pred(x));
else if (opt == 6) printf("%d\n", treap.succ(x));
}
return 0;
}
数组版:
#include <bits/stdc++.h>
#define mp make_pair
using namespace std;
const int inf = 0x7fffffff, N = 1e5 + 5;
typedef pair<int, int> par;
int rt = 0, cnt = 0;
struct Treap {
int ls[N], rs[N], size[N], prio[N], w[N];
inline void maintain(int p) { size[p] = size[ls[p]] + size[rs[p]] + 1; }
inline int merge(int x, int y) {
if (!x) { maintain(y); return y; }
if (!y) { maintain(x); return x; }
if (prio[x] < prio[y]) { rs[x] = merge(rs[x], y); maintain(x); return x; }
else { ls[y] = merge(x, ls[y]); maintain(y); return y; }
}
par split(int p, int k) {
if (!k) return mp(0, p);
if (k <= size[ls[p]]) {
par tem = split(ls[p], k);
ls[p] = tem.second;
maintain(p);
return mp(tem.first, p);
}
par tem = split(rs[p], k - size[ls[p]] - 1);
rs[p] = tem.first;
maintain(p);
return mp(p, tem.second);
}
int queryrank(int p, int x) {
int ans = 0, tem = inf;
while (p) {
if (x == w[p]) tem = min(tem, ans + size[ls[p]] + 1);
if (x > w[p]) ans += size[ls[p]] + 1, p = rs[p];
else p = ls[p];
}
return (tem == inf) ? ans : tem;
}
inline void insert(int x) {
int k = queryrank(rt, x);
par tem = split(rt, k);
w[++ cnt] = x, size[cnt] = 1, prio[cnt] = rand();
rt = merge(tem.first, cnt), rt = merge(rt, tem.second);
}
void erase(int x) {
int k = queryrank(rt, x);
par p1 = split(rt, k - 1); par p2 = split(p1.second, 1);
rt = merge(p1.first, p2.second);
}
int select(int p, int k) {
for (; ; ) {
if (k == size[ls[p]] + 1) return w[p];
if (k < size[ls[p]] + 1) p = ls[p];
else k -= size[ls[p]] + 1, p = rs[p];
}
}
int pred(int p, int x) {
int ans = -inf;
while(p) {
if(x > w[p]) ans = max(ans, w[p]), p = rs[p];
else p = ls[p];
}
return ans;
}
int succ(int p, int x) {
int ans = inf;
while(p) {
if(x < w[p]) ans = min(ans, w[p]), p = ls[p];
else p = rs[p];
}
return ans;
}
} treap;
int main() {
int n;
scanf("%d", &n);
while (n --) {
int opt, x;
scanf("%d%d", &opt, &x);
if (opt == 1) treap.insert(x);
else if (opt == 2) treap.erase(x);
else if (opt == 3) printf("%d\n", treap.queryrank(rt, x));
else if (opt == 4) printf("%d\n", treap.select(rt, x));
else if (opt == 5) printf("%d\n", treap.pred(rt, x));
else if (opt == 6) printf("%d\n", treap.succ(rt, x));
}
return 0;
}