幸运数字
题目链接:luogu P3292
题目大意
给你一个树,点有点权。
多次询问,每次问你树上的一条路径,它上面所有点任选记得,使得它们点权的异或和最大。
思路
接着我们考虑不用倍增合并。
那我们自然想到这个题,那我们可以用同样的方法把它扩展到树上。
同样是求出 LCA,但我们这次找出两条链就是要找线性基里面深度大于等于
d
e
g
L
C
A
deg_{LCA}
degLCA 的合并。
那复杂度就是 O ( n ( l o g n + l o g G ) + q ( l o g 2 G + l o g n ) ) O(n(logn+logG)+q(log^2G+logn)) O(n(logn+logG)+q(log2G+logn))。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
struct node {
int to, nxt;
}e[40001];
int n, q, x, y, le[20001], KK;
int tim[20001][61], deg[20001];
int fa[20001][21];
ll val[20001], d[20001][61], ans[61];
void add(int x, int y) {
e[++KK] = (node){y, le[x]}; le[x] = KK;
e[++KK] = (node){x, le[y]}; le[y] = KK;
}
void xxj_add(int x, ll num, int ti) {//让后面优先的线性基插入
for (int i = 60; i >= 0; i--)
if ((num >> i) & 1) {
if (!d[x][i]) {
d[x][i] = num;
tim[x][i] = ti;
}
if (ti > tim[x][i]) {
swap(ti, tim[x][i]);
swap(num, d[x][i]);
}
num ^= d[x][i];
}
}
void dfs(int now, int father) {//dfs 预处理 LCA 倍增要用的以及前缀和求出 now 到根节点路径的线性基
fa[now][0] = father;
deg[now] = deg[father] + 1;
for (int i = 0; i <= 60; i++)
d[now][i] = d[father][i], tim[now][i] = tim[father][i];
xxj_add(now, val[now], deg[now]);
for (int i = le[now]; i; i = e[i].nxt)
if (e[i].to != father)
dfs(e[i].to, now);
}
int LCA(int x, int y) {//LCA
if (deg[y] > deg[x]) swap(x, y);
for (int i = 20; i >= 0; i--)
if (deg[fa[x][i]] >= deg[y])
x = fa[x][i];
if (x == y) return x;
for (int i = 20; i >= 0; i--)
if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
void add_ans(ll x) {//往答案线性基中插入数(不需要特殊的构造)
for (int i = 60; i >= 0; i--)
if ((x >> i) & 1) {
if (!ans[i]) {
ans[i] = x;
break;
}
x ^= ans[i];
}
}
void merge(int x, int y, int k) {//合并线性基
for (int i = 0; i <= 60; i++)
if (tim[x][i] >= k)//记得深度不能比 LCA 的浅,否则就不是路径上的
ans[i] = d[x][i];
for (int i = 0; i <= 60; i++)
if (tim[y][i] >= k)//跟上面同理
add_ans(d[y][i]);
}
ll get_max_ans() {//求答案线性基的异或最大值
ll re = 0;
for (int i = 60; i >= 0; i--)
if ((re ^ ans[i]) > re)
re ^= ans[i];
return re;
}
int main() {
scanf("%d %d", &n, &q);
for (int i = 1; i <= n; i++) scanf("%lld", &val[i]);
for (int i = 1; i < n; i++) {
scanf("%d %d", &x, &y);
add(x, y);
}
dfs(1, 0);
for (int i = 1; i <= 20; i++)//倍增
for (int j = 1; j <= n; j++)
fa[j][i] = fa[fa[j][i - 1]][i - 1];
while (q--) {
memset(ans, 0, sizeof(ans));
scanf("%d %d", &x, &y);
merge(x, y, deg[LCA(x, y)]);
printf("%lld\n", get_max_ans());
}
return 0;
}