零、碎碎念
打比赛没遇上可持久化Trie,做个CMU 15-445的project0,上来就碰上了……
关于Trie详见:[Trie树/字典树的原理及实现C/C++]_trie字典树原理-CSDN博客
一、可持久化Trie
1.1 基本思想
可持久化Trie和可持久化线段树类似,因为每次插入只有一条路径走到底,所以不需要每个版本开一棵树。
比如下面就是在01Trie上依次插入[2, 5, 7]的三个版本
因而,我们动态开点,在上一个版本的基础上,增加新的节点,就得到了新版本的Trie。
为了方便叙述,下面都以01Trie为例。
1.2 Trie基本结构
struct Trie{
static constexpr int ALPHABET = 2; // 字符集
static constexpr int B = 24; // 二进制位范围
struct Node{ // 结点定义
Node():cnt(0), son{} {}
std::array<int, ALPHABET> son;
int cnt;
};
std::vector<Node> tr; // 结点池
std::vector<int> root; // 各版本根节点
Trie() {
tr.emplace_back(Node());
root.emplace_back(0); // 初始化空节点0 为 0 号版本
}
int newNode(){ // 动态开点
tr.emplace_back();
return (int)tr.size() - 1;
}
void add(int v) {}
int max_xor(int x, int y, int v) {}
};
1.3 插入操作
- 每插入一个新数字v,都会生成一个新版本的Trie
- 记新版本编号为y,上一个版本编号为x
- 按位从高到低遍历v,记当前遍历到第 i 位(位从0开始编号)
- 令 j = v >> i & 1,那么 j 就是v在 第 i 位的值,即 tr[y].son[j] 是我们要生成的结点
- 开新点给tr[y].son[i],tr[y].son[!j] 继承 tr[x].son[!j]
- 遍历完位,插入结束
- 时间复杂度:O(log v),每个版本只开辟了O(log v)个新结点
代码实现
void add(int v) {
int x = root.back(), y = newNode();
root.emplace_back(y);
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
tr[y].son[!j] = tr[x].son[!j];
tr[y].son[j] = newNode();
x = tr[x].son[j], y = tr[y].son[j];
tr[y].cnt = tr[x].cnt + 1;
}
}
1.4 查询操作
01 Trie 的查询操作一般都是查询最大异或和。
(两两异或第K大查询见OJ练习2.2)
可持久化Trie 支持我们查询任意区间内子序列和任意数字 v 的最大异或和
- 查询区间[l, r] 对应版本 [l, r],待查询数字v,返回结果为res
- 令 y = root[r],x = root[l - 1],x 显然是边界,我们不能伸入x以及x左边的版本
- 按位从高到低遍历v,记当前遍历到第 i 位
- 令j = v >> i & 1
- 如果 tr[tr[y].son[!j]].cnt > tr[tr[x].son[!j]].cnt,说明 !j 这条路径上有结点,并且未伸入边界,我们就令 x = tr[x].son[!j], y = tr[y].son[!j],res |= 1 << i
- 否则 x = tr[x].son[j],y = tr[y].son[j]
- 遍历结束,返回res
代码实现
int max_xor(int x, int y, int v) {
x = root[x], y = root[y];
int res = 0;
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
if (tr[tr[y].son[!j]].cnt > tr[tr[x].son[!j]].cnt) {
res |= 1 << i;
j ^= 1;
}
y = tr[y].son[j];
x = tr[x].son[j];
}
return res;
}
1.5 完整代码
其它功能,根据不同题目,分析编写即可。
struct Trie{
static constexpr int ALPHAEBT = 2;
static constexpr int B = 24;
struct Node{
Node():cnt(0), son{} {}
std::array<int, ALPHAEBT> son;
int cnt;
};
std::vector<Node> tr;
std::vector<int> root;
Trie() {
tr.emplace_back(Node());
root.emplace_back(0);
}
int newNode(){
tr.emplace_back();
return (int)tr.size() - 1;
}
void add(int v) {
int x = root.back(), y = newNode();
root.emplace_back(y);
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
tr[y].son[!j] = tr[x].son[!j];
tr[y].son[j] = newNode();
x = tr[x].son[j], y = tr[y].son[j];
tr[y].cnt = tr[x].cnt + 1;
}
}
int max_xor(int x, int y, int v) {
x = root[x], y = root[y];
int res = 0;
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
if (tr[tr[y].son[!j]].cnt > tr[tr[x].son[!j]].cnt) {
res |= 1 << i;
j ^= 1;
}
y = tr[y].son[j];
x = tr[x].son[j];
}
return res;
}
};
二、OJ练习
2.1 最大异或和
原题链接
P4735 最大异或和 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
思路分析
如果没学过可持久化Trie,我大概会离线处理 + 01Trie + 前缀和来做
但现在不同了,我们可以用可持久化Trie + 前缀和轻松解决
先在Trie中插入0,这是前缀异或和都要设置的哨兵
插入的部分我们选择插入前缀异或和,后面会用到
我们记 前i个数异或和为 s[i]
对于查询的部分,因为插入了0,所以 [l, r] 对应 版本/区间 [l + 1, r + 1]
因为我们选取的后缀不能空,所以相当于 求 s[p] ^ s[p + 1] ^ … ^ s[r] ^ x 的最值(即a[r + 1] 必须取)
然后查询 [l, r] 内和v 的最大异或和即可
因为查询区间是[l, r],所以左边界应该是root[l - 1]
AC代码
#include <bits/stdc++.h>
using i64 = long long;
using u32 = unsigned int;
using u64 = unsigned long long;
struct Trie{
static constexpr int ALPHAEBT = 2;
static constexpr int B = 24;
struct Node{
Node():cnt(0), son{} {}
std::array<int, ALPHAEBT> son;
int cnt;
};
std::vector<Node> tr;
std::vector<int> root;
Trie() {
tr.emplace_back(Node());
root.emplace_back(0);
}
int newNode(){
tr.emplace_back();
return (int)tr.size() - 1;
}
void add(int v) {
int x = root.back(), y = newNode();
root.emplace_back(y);
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
tr[y].son[!j] = tr[x].son[!j];
tr[y].son[j] = newNode();
x = tr[x].son[j], y = tr[y].son[j];
tr[y].cnt = tr[x].cnt + 1;
}
}
int max_xor(int x, int y, int v) {
x = root[x], y = root[y];
int res = 0;
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
if (tr[tr[y].son[!j]].cnt > tr[tr[x].son[!j]].cnt) {
res |= 1 << i;
j ^= 1;
}
y = tr[y].son[j];
x = tr[x].son[j];
}
return res;
}
};
auto FIO = []{
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
return 0;
}();
int main() {
int n, m;
std::cin >> n >> m;
Trie tr;
tr.add(0);
int s = 0;
for (int i = 0, a; i < n; ++ i)
std::cin >> a, tr.add(s ^= a);
for (int i = 0, l, r, x; i < m; ++ i) {
char op;
std::cin >> op;
if (op == 'A') {
std::cin >> x;
tr.add(s ^= x);
}
else {
std::cin >> l >> r >> x;
std::cout << tr.max_xor(l - 1, r, s ^ x) << '\n';
}
}
return 0;
}
2.2 异或粽子(kth_max_xor)
原题链接
[P5283 十二省联考 2019] 异或粽子 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
思路分析
我们选择在01Trie中插入前缀和
那么本题就转换成了求数组中前K大两数异或值之和
我们考虑固定一个右端点,如何求第k大xor?
我们Trie的结点存储了cnt,代表了该路径上该位为0 / 1的方案数
那么类似于 我们在平衡树(如Splay、Treap)上查kth
如果路径可走:
- cnt >= k,那就走
- 否则k -= cnt
我们在堆中插入n个位置以及rank = 1时的max_xor
然后弹k次,不断维护即可
时间复杂度:O(k log^2 n)
注意:本代码无法通过本题加强版:https://codeforces.com/problemset/problem/241/B
事实上,可以寻找O(n log^2 n)做法
AC代码
#include <bits/stdc++.h>
// #include <ranges>
using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;
constexpr int P = 1'000'000'007;
struct Trie{
static constexpr int ALPHABET = 2;
static constexpr int B = 33;
struct Node{
std::array<int, ALPHABET> son;
int cnt;
Node(): son{}, cnt(0) {}
};
std::vector<Node> tr;
std::vector<int> root;
Trie(){
tr.emplace_back(Node());
root.emplace_back(0);
}
int newNode() {
tr.emplace_back();
return (int)tr.size() - 1;
}
void add(i64 v) {
int x = root.back(), y = newNode();
root.emplace_back(y);
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
tr[y].son[!j] = tr[x].son[!j];
tr[y].son[j] = newNode();
x = tr[x].son[j], y = tr[y].son[j];
tr[y].cnt = tr[x].cnt + 1;
}
}
i64 max_xor(int x, int y, int v) {
x = root[x], y = root[y];
i64 res = 0;
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
if (tr[tr[y].son[!j]].cnt > tr[tr[x].son[!j]].cnt) {
res |= 1 << i;
j ^= 1;
}
y = tr[y].son[j];
x = tr[x].son[j];
}
return res;
}
i64 max_xor(int x, int y, i64 v, int k) {
x = root[x], y = root[y];
i64 res = 0;
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
if (tr[tr[y].son[!j]].cnt > tr[tr[x].son[!j]].cnt) {
if (k <= tr[tr[y].son[!j]].cnt - tr[tr[x].son[!j]].cnt) {
res |= 1LL << i;
j ^= 1;
}
else
k -= tr[tr[y].son[!j]].cnt - tr[tr[x].son[!j]].cnt;
}
y = tr[y].son[j];
x = tr[x].son[j];
}
return res;
}
};
void solve() {
int n, k;
std::cin >> n >> k;
Trie tr;
std::priority_queue<std::tuple<i64, int, int, i64>> pq;
i64 s = 0;
tr.add(0);
for (int i = 0; i < n; ++ i) {
i64 a;
std::cin >> a;
tr.add(s ^= a);
pq.emplace(tr.max_xor(0, i + 2, s, 1), i + 2, 1, s);
}
i64 res = 0;
while (k --) {
auto [v, r, rank, a] = pq.top();
pq.pop();
res += v;
++ rank;
pq.emplace(tr.max_xor(0, r, a, rank), r, rank, a);
}
std::cout << res;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
// std::cin >> t;
while (t--) {
solve();
}
return 0;
}
2.3 ALO
原题链接
[P4098 HEOI2013] ALO - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
思路分析
喔的写法是 单调栈 + ST表 + 可持久化Trie
单调栈处理每个下标左边第一个比自己大的,右边第一个比自己大的
题解有人用链表轻松处理左右第二个大的,但是我没看懂,所以还是写了ST表
然后我们枚举每个数,记 左边第一个大的为l,第二个为ll,同理有r,rr
那么可以作为次大值的区间就是 [ll + 1, r - 1], [l + 1, rr - 1]
在可持久化Trie上查询即可
时间复杂度:O(nlogn)
AC代码
#include <bits/stdc++.h>
// #include <ranges>
using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;
constexpr int P = 1'000'000'007;
template<class T, class Func, const int M = 30>
struct ST {
Func F;
T n;
std::vector<T> nums;
std::vector<int> LOG2;
std::vector<std::array<T, M>> f;
ST (const std::vector<T>& _nums) : n(_nums.size()), nums(_nums), LOG2(n + 1), f(n) {
LOG2[2] = 1;
for (int i = 3; i <= n; i ++ )
LOG2[i] = LOG2[i >> 1] + 1;
for (int i = 0; i < n; i ++ )
f[i][0] = nums[i];
for (int j = 1; j < M; j ++)
for (int i = 0; i < n && i + (1 << (j - 1)) < n; i ++)
f[i][j] = F(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
T query(int l, int r) {
int k = LOG2[r - l + 1];
return F(f[l][k], f[r - (1 << k) + 1][k]);
}
};
struct Func{
int operator()(int x, int y){
return x > y ? x : y;
}
};
struct Trie{
static constexpr int ALPHABET = 2;
static constexpr int B = 30;
struct Node{
std::array<int, ALPHABET> son;
int cnt;
Node(): son{}, cnt(0) {}
};
std::vector<Node> tr;
std::vector<int> root;
Trie(){
tr.emplace_back(Node());
root.emplace_back(0);
}
int newNode() {
tr.emplace_back();
return (int)tr.size() - 1;
}
void add(i64 v) {
int x = root.back(), y = newNode();
root.emplace_back(y);
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
tr[y].son[!j] = tr[x].son[!j];
tr[y].son[j] = newNode();
x = tr[x].son[j], y = tr[y].son[j];
tr[y].cnt = tr[x].cnt + 1;
}
}
i64 max_xor(int x, int y, int v) {
x = root[x], y = root[y];
i64 res = 0;
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
if (tr[tr[y].son[!j]].cnt > tr[tr[x].son[!j]].cnt) {
res |= 1 << i;
j ^= 1;
}
y = tr[y].son[j];
x = tr[x].son[j];
}
return res;
}
i64 max_xor(int x, int y, i64 v, int k) {
x = root[x], y = root[y];
i64 res = 0;
for (int i = B - 1; ~i; -- i) {
int j = v >> i & 1;
if (tr[tr[y].son[!j]].cnt > tr[tr[x].son[!j]].cnt) {
if (k <= tr[tr[y].son[!j]].cnt - tr[tr[x].son[!j]].cnt) {
res |= 1LL << i;
j ^= 1;
}
else
k -= tr[tr[y].son[!j]].cnt - tr[tr[x].son[!j]].cnt;
}
y = tr[y].son[j];
x = tr[x].son[j];
}
return res;
}
};
void solve() {
int n;
std::cin >> n;
Trie tr;
std::vector<int> a(n), pre(n, -1), suf(n, n);
std::vector<int> st;
for (int i = 0; i < n; ++ i) {
std::cin >> a[i];
tr.add(a[i]);
while (st.size() && a[i] > a[st.back()]) {
suf[st.back()] = i;
st.pop_back();
}
if (st.size()) pre[i] = st.back();
st.push_back(i);
}
i64 res = 0;
ST<int, Func> rmq(a);
auto getsuf = [&](int lo, int hi, int v) -> int {
int l = lo;
int res = -1;
while (lo <= hi) {
int x = lo + hi >> 1;
if (rmq.query(l, x) > v) res = x, hi = x - 1;
else lo = x + 1;
}
return res;
};
auto getpre = [&](int lo, int hi, int v) -> int {
int r = hi;
int res = -1;
while (lo <= hi) {
int x = lo + hi >> 1;
if (rmq.query(x, r) > v) res = x, lo = x + 1;
else hi = x - 1;
}
return res;
};
for (int i = 0; i < n; ++ i) {
int l = pre[i], r = suf[i];
int ll = getpre(0, l - 1, a[i]), rr = getsuf(r + 1, n - 1, a[i]);
if (~l)
res = std::max(res, tr.max_xor(~ll ? ll + 1 : 0, r < n ? r : n, a[i]));
if (r < n)
res = std::max(res, tr.max_xor(~l ? l + 1 : 0, ~rr ? rr : n, a[i]));
}
std::cout << res;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t = 1;
// std::cin >> t;
while (t--) {
solve();
}
return 0;
}