普通01字典树带删除, 求最大/小异或值;
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const ll N = 1e5 + 5;
ll n, m;
ll a[N], b[N];
ll ch[32 * N][2];
ll val[32 * N];
ll sz;
void init()
{
sz = 0;
memset(ch, 0, sizeof ch);
}
void insert(ll num)
{
ll u = 0;
for(ll i = 32; i >= 0; i --)
{
ll c = ((num >> i) & 1);
if(!ch[u][c])
{
ch[u][c] = ++ sz;
}
u = ch[u][c];
val[u] ++;
}
}
void remove(ll num)
{
ll u = 0;
for(ll i = 32; i >= 0; i--)
{
ll c = ((num >> i) & 1);
ll pre = u;
u = ch[u][c];
if(--val[u] == 0) ch[pre][c] = 0; // 断掉链接
}
}
ll query(ll num, bool flag) // flag 查询最大/最小异或true 最大, false 最小
{
ll p = 0, v = 0;
ll ans = 0;
for(ll i = 32; i >= 0; i--)
{
ll c = ((num >> i) & 1);
if(flag == true)
{
if(ch[p][c ^ 1])
{
p = ch[p][c ^ 1];
ans += 1 << i;
}
else if(ch[p][c])
{
p = ch[p][c];
}
else
break;
}
else
{
if(ch[p][c])
{
p = ch[p][c];
}
else if(ch[p][c ^ 1])
p = ch[p][c ^ 1], ans += 1 << i;
else
break;
}
}
return ans;
}
int main()
{
init();
insert(10);
insert(7);
cout << query(1, true) << endl;
cout << query(2, false) << endl;
return 0;
}
可持续01字典树
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const ll N = 260005;
int rt[N], b[N];
struct node
{
int cnt;
int ch[N * 32][2], sum[N * 32];
void init()
{
cnt = 0, b[0] = 1; // b[i], 记录第2^i
for(int i = 1; i <= 30; i ++) b[i] = b[i - 1] * 2;
memset(ch, 0, sizeof ch);
memset(sum, 0, sizeof sum);
}
int insert_(int x, int val) // x为上一个trie树的根节点
{
int res, y;
y = res = ++ cnt; // 开新点
for(int i = 30; i >= 0; i --)
{
ch[y][0] = ch[x][0];
ch[y][1] = ch[x][1];
sum[y] = sum[x] + 1;
int tmp = val & b[i];
tmp >>= i; // 求出val在二进制下第i位的值
x = ch[x][tmp]; // x(tmp方向)向下走
ch[y][tmp] = ++ cnt; // 开点
y = ch[y][tmp]; // y也继续向下走
}
sum[y] = sum[x] + 1; // 到叶子节点后, 表示在同一个位置来过的次数
return res; // 返回x建的01字典树的根节点
}
int query(int l, int r, int val) // 返回与val异或最大的那个数字
{
int ans = 0;
for(int i = 30; i >= 0; i --)
{
int tmp = val&b[i];
tmp >>= i; // 取val的第i位(二进制)
if(sum[ch[r][tmp ^ 1]] - sum[ch[l][tmp ^ 1]])
{
l = ch[l][tmp ^ 1], r = ch[r][tmp ^ 1];
if(r != 0 && tmp == 0) ans += b[i];
}
else
{
l = ch[l][tmp], r = ch[r][tmp];
if(r != 0 && tmp == 1) ans += b[i];
}
}
return ans;
}
}trie;
int main()
{
int a[120];
trie.init(); //必须初始化
for(int i = 1; i <= 110; i ++) a[i] = i;
rt[0] = 0;
for(int i = 1; i <= 110; i ++)
rt[i] = trie.insert_(rt[i - 1], a[i]);
int l = 2, r = 75;
int x = 0;
cout << trie.query(rt[l - 1], rt[r], x); // 查询a[l] ~ a[r]中一个数字与x异或的最大值
// 此处必须要用rt[l - 1], rt[r];
return 0;
}