传送门:https://www.lydsy.com/JudgeOnline/problem.php?id=2588
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。最后一个询问不输出换行符
Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7
这个题可以在用倍增法求fa数组的同时将每一个节点加到主席树中去,并且以该点的父节点作为前驱节点,因此主席树表示的就是每个节点到根节点这条链的信息,所以和求树上两点之间的距离就有一些类似。节点u和节点v之前的第k大就等于tree[u]-tree[lca(u,v)]+tree[v]-tree[lca(u,v)]+tree[lca(u,v)]-tree[fa[lca(u,v)]]化简就可以得到第k大就为tree[u]+tree[v]-tree[lca(u,v)]-tree[fa[lca(u,v)]].同时由于dfs是深度优先,所以每次访问的就是一条链,不必把每个节点重新编号,可以在dfs的过程中更新信息。
代码:
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;
const int maxm = 21;
struct node {
int sum, ls, rs;
} tree[maxn * 20];
int a[maxn], root[maxn], deep[maxn], fa[maxn][maxm + 1];
int n, m, tot, siz;
vector<int> e[maxn], ha;
void update(int pre, int &cur, int L, int R, int x) {
cur = ++tot;
tree[cur] = tree[pre];
tree[cur].sum++;
if (L == R) return;
int mid = L + R >> 1;
if (x <= mid)
update(tree[pre].ls, tree[cur].ls, L, mid, x);
else
update(tree[pre].rs, tree[cur].rs, mid + 1, R, x);
}
int query(int a, int b, int c, int d, int L, int R, int k) {
if (L == R) return ha[L - 1];
int temp = tree[tree[a].ls].sum + tree[tree[b].ls].sum -
tree[tree[c].ls].sum - tree[tree[d].ls].sum;
int mid = L + R >> 1;
if (k <= temp)
return query(tree[a].ls, tree[b].ls, tree[c].ls, tree[d].ls, L, mid, k);
else
return query(tree[a].rs, tree[b].rs, tree[c].rs, tree[d].rs, mid + 1, R,
k - temp);
}
void dfs(int u) {
deep[u] = deep[fa[u][0]] + 1;
update(root[fa[u][0]], root[u], 1, siz, a[u]);
for (int i = 1; i <= maxm; i++) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
if (!fa[u][i]) break;
}
for (int i=0;i<e[u].size();i++) {
int v=e[u][i];
if (v == fa[u][0]) continue;
fa[v][0] = u;
dfs(v);
}
}
int LCA(int u, int v) {
if (deep[u] < deep[v]) swap(u, v);
for (int i = maxm; i >= 0; i--) {
if (deep[fa[u][i]] >= deep[v]) u = fa[u][i];
}
for (int i = maxm; i >= 0; i--) {
if (fa[u][i] != fa[v][i]) {
u = fa[u][i];
v = fa[v][i];
}
}
if (u != v) u = fa[u][0];
return u;
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), ha.push_back(a[i]);
sort(ha.begin(), ha.end());
ha.erase(unique(ha.begin(), ha.end()), ha.end());
siz = (int)ha.size();
for (int i = 1; i <= n; i++)
a[i] = lower_bound(ha.begin(), ha.end(), a[i]) - ha.begin() + 1;
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1);
int ans = 0;
for (int i = 1; i <= m; i++) {
int u, v, k;
scanf("%d%d%d", &u, &v, &k);
u ^= ans;
int lca = LCA(u, v);
//cout << lca << endl;
ans = query(root[u], root[v], root[lca], root[fa[lca][0]], 1, siz, k);
printf("%d\n", ans);
}
return 0;
}