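/*
题意:
题目链接:https://nanti.jisuanke.com/t/17120
给出一棵树,每个节点有一个价值。给出 q 个询问 (u, v, k):设从结点 u 到结点 v 的路径上的点依次为 a0, a1, a2, …, am,求其中下标为 0, k, 2k, …, pk(pk ≤ m)的点的价值异或和,即 val(a0) ^ val(ak) ^ val(a2k) ^ … ^ val(apk)。
思路:
按 k 的大小分类处理(根号分治)。先求出 u、v 的 LCA:k ≥ 250 时沿路径每次跳 k 步直接暴力求解;k < 250 时预处理 sum[u][x],表示从结点 u 出发、每次向上走 x 步直到根,沿途经过的所有结点的价值异或和,查询时用两段前缀异或相消即可得到答案。
代码:
*/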
#include <bits/stdc++.h>
using namespace std;
// ---- Sizes and tunables -------------------------------------------------
const int MAXN = 50000 + 10;  // capacity for node indices 1..n
const int NX = MAXN * 2;      // not referenced elsewhere in this file
const int M = 16;             // binary-lifting levels: fa[][] jumps up to 2^M - 1 levels
const int K = 250;            // threshold: k < K uses sum[][], k >= K walks the path
// Every node depth is < MAXN, so with 2^M > MAXN a lift of 2^M or more
// levels always overshoots the root; the lifting helpers depend on this.
static_assert((1 << M) > MAXN, "fa[][] must cover every possible node depth");

// ---- Per-node data (1-based node indices; index 0 is the "above the root"
// sentinel, which is never written and therefore stays zero) ----------------
int a[MAXN];                 // a[u]: value of node u
int fa[MAXN][M];             // fa[u][j]: 2^j-th ancestor of u (0 past the root)
int sum[MAXN][K];            // sum[u][s]: XOR of a[] over u and its ancestors at distances s, 2s, ...
std::vector<int> tree[MAXN]; // undirected adjacency lists
int dep[MAXN], ct;           // dep[u]: depth of u (root = 1); ct: not referenced in this file
int sta[MAXN], top;          // root-to-current-node path used by dfs(); top = its length
// Depth-first traversal from u, whose parent is f (the caller passes 0 for
// the root). For every node it reaches, it records:
//   fa[u][j]  - the 2^j-th ancestor of u, built from the (j-1) table;
//   dep[u]    - the depth of u, where the root has depth 1;
//   sum[u][s] - for 1 <= s < K, a[u] XOR sum[ancestor at distance s][s],
//               i.e. the XOR over u, u's s-th, 2s-th, ... ancestors.
// sta[0..top-1] is the current root-to-u path, so the ancestor at distance s
// is sta[top - 1 - s]. The recursion depth equals the height of the tree.
void dfs(int u, int f) {
    fa[u][0] = f;
    for (int lvl = 1; lvl < M; ++lvl) {
        const int mid = fa[u][lvl - 1];
        fa[u][lvl] = fa[mid][lvl - 1];
    }

    // Push u onto the path; its depth is the new path length.
    sta[top] = u;
    ++top;
    dep[u] = top;

    for (int step = 1; step < K; ++step) {
        int acc = a[u];
        // Only chain to the ancestor when one exists at distance `step`.
        if (step < top) acc ^= sum[sta[top - 1 - step]][step];
        sum[u][step] = acc;
    }

    for (const int child : tree[u]) {
        if (child != f) dfs(child, u);
    }

    // Pop u before returning to the parent.
    --top;
}
// Lowest common ancestor of x and y using the binary-lifting table fa[][]
// and depths dep[] (both filled by dfs()).
int lca(int x, int y) {
    // Make x the deeper node.
    if (dep[y] > dep[x]) std::swap(x, y);

    // Raise x to y's depth, one set bit of the depth gap at a time.
    for (int bit = 0, gap = dep[x] - dep[y]; bit < M && gap != 0; ++bit, gap >>= 1) {
        if (gap & 1) x = fa[x][bit];
    }
    if (x == y) return x;

    // Climb both nodes while their ancestors still differ; they end up as
    // the two children of the LCA.
    for (int bit = M - 1; bit >= 0; --bit) {
        const int px = fa[x][bit];
        const int py = fa[y][bit];
        if (px != py) {
            x = px;
            y = py;
        }
    }
    return fa[x][0];
}
// Moves u up by dx levels in place, using the binary-lifting table fa[][].
// Lifting past the root leaves u at 0, the never-written sentinel whose
// dep[] and sum[][] entries are zero; the query code relies on that to
// detect "went above the LCA".
void fuck(int& u, int dx) {
    for (int i = 0; i < M && dx > 0; i++) {
        if (dx & 1) u = fa[u][i];
        dx >>= 1;
    }
    // Bits of dx at or above 2^M cannot be applied through fa[][] and used
    // to be silently dropped (e.g. dx = 2^M + 1 lifted by only 1 level).
    // 2^M exceeds every node depth, so such a jump always overshoots the root.
    if (dx > 0) u = 0;
}
int n, q;

// Answers one query when k >= K by walking the path explicitly.
// f must be lca(x, y). Path positions run 0 (x) .. m (y), m = dist(x, y);
// the result is the XOR of a[] over positions 0, k, 2k, ..., <= m.
static int query_large_k(int x, int y, int f, int k) {
    // y is at position m; the last selected position is m - m % k.
    const int skip = (dep[x] - dep[f] + dep[y] - dep[f]) % k;
    int ans = 0;
    // x -> f leg: positions 0, k, 2k, ... strictly below f.
    while (dep[x] > dep[f]) {
        ans ^= a[x];
        fuck(x, k);
    }
    // y -> f leg: start at the last selected position and step back by k,
    // including f itself when its position is a multiple of k.
    fuck(y, skip);
    while (dep[y] >= dep[f]) {
        ans ^= a[y];
        fuck(y, k);
    }
    return ans;
}

// Answers one query when k < K using the prefix table sum[][k]:
// sum[v][k] ^ sum[w][k], with w the ancestor of v at distance c*k, equals
// the XOR over v and its ancestors at distances k, 2k, ..., (c-1)*k.
static int query_small_k(int x, int y, int f, int k) {
    const int up_x = dep[x] - dep[f];
    const int up_y = dep[y] - dep[f];
    // Move y to the last selected position on its leg (m - m % k).
    fuck(y, (up_x + up_y) % k);
    int ans = 0;
    if (up_x > 0) {
        // Selected nodes strictly below f on the x leg: ceil(up_x / k).
        const int cnt = (up_x - 1) / k + 1;
        ans ^= sum[x][k];
        fuck(x, cnt * k);
        ans ^= sum[x][k];
    }
    if (dep[y] >= dep[f]) {
        // Selected nodes from y up to f inclusive.
        const int cnt = (dep[y] - dep[f]) / k + 1;
        ans ^= sum[y][k];
        fuck(y, cnt * k);
        ans ^= sum[y][k];
    }
    return ans;
}

// Dispatches one (x, y, k) query to the strategy chosen by the K threshold.
// Assumes k >= 1 (the modulo above is undefined for k == 0) — per problem.
static int answer_query(int x, int y, int k) {
    const int f = lca(x, y);
    return k >= K ? query_large_k(x, y, f, k) : query_small_k(x, y, f, k);
}

// Reads test cases until EOF: n q, then n-1 edges, n node values, q queries.
int main() {
    //freopen("in.txt", "r", stdin);
    while (scanf("%d%d", &n, &q) == 2) {
        for (int i = 1; i <= n; i++) tree[i].clear();
        for (int i = 1; i < n; i++) {
            int u = 0, v = 0;
            // Stop on truncated input instead of using unread values.
            if (scanf("%d%d", &u, &v) != 2) return 0;
            tree[u].push_back(v);
            tree[v].push_back(u);
        }
        for (int i = 1; i <= n; i++) {
            if (scanf("%d", &a[i]) != 1) return 0;
        }
        dfs(1, 0);
        while (q--) {
            int x = 0, y = 0, k = 0;
            if (scanf("%d%d%d", &x, &y, &k) != 3) return 0;
            printf("%d\n", answer_query(x, y, k));
        }
    }
    return 0;
}