链接
题目描述
给定一棵 n 个节点的树,每个点有一个权值。有 m 个询问,每次给出 u,v,k,需要回答 u xor last 与 v 这两个节点之间路径上第 k 小的点权,其中 last 为上一次询问的答案(初始为 0)。
思路
对每个节点建一棵可持久化权值线段树(主席树):以父亲的版本为历史版本,在其基础上插入该节点的点权,得到儿子的当前版本。
这样每次从父亲向儿子更新,节点 x 的版本就记录了根到 x 路径上的全部点权,那么路径上落在某值域区间内的点数就是
$$sum_{u \oplus last} + sum_{v} - sum_{lca} - sum_{fa_{lca}}$$
其中 $sum_x$ 表示 $x$ 的版本中落在当前值域区间内的点数,$lca$ 为两点的最近公共祖先,$fa_{lca}$ 为它的父亲。在这四个版本上同时二分值域,即可求出第 $k$ 小。
思路很简单,但我调试了很久。
代码
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<utility>
#include<vector>
#define find(x) (lower_bound(p+1,p+q+1,a[x])-p)
using namespace std;
int T[1000005];        // T[x]: root (index into tr) of vertex x's segment-tree version; T[0] is the empty tree
int cnt;               // number of segment-tree nodes allocated so far
int n;                 // number of tree vertices
int m;                 // number of queries
int t;                 // number of directed arcs stored in g
int q;                 // number of distinct weights after deduplication
int p[1000005];        // p[1..q]: sorted distinct weights
int a[1000005];        // a[i]: weight of vertex i
int h[1000005];        // h[x]: most recently added arc leaving x (0 terminates the list)
int dep[1000005];      // dep[x]: depth of x, with the root at depth 1 and sentinel 0 at depth 0
int log[1000005];      // floor(log2(i)) table filled in main; not read elsewhere in this file
int fa[1000005][30];   // fa[x][j]: 2^j-th ancestor of x (0 above the root); only j <= 18 is used
// One directed arc of the forward-star adjacency list g.
struct node
{
    int to;    // vertex the arc points to
    int next;  // index of the next arc leaving the same vertex (0 terminates the list)
} g[200005];
// Node of the persistent weight segment tree, stored in the pool tr.
struct trr
{
    int k;   // how many inserted values fall into this node's value range
    int ls;  // index of the left child in tr
    int rs;  // index of the right child in tr
} tr[32000006];
// Adds the undirected edge (x, y) to the forward-star list g as two directed arcs,
// each prepended to the arc list of its source vertex (h holds the list heads).
void add(int x, int y)
{
    // Braced temporaries are standard C++; the C-style compound literal
    // `(node){...}` is only accepted as a GNU extension.
    g[++t] = node{y, h[x]};
    h[x] = t;
    g[++t] = node{x, h[y]};
    h[y] = t;
}
// Allocates a fresh node in tr, writes its index to x, and initialises it as a
// copy of node `last` (the path-copying step of the persistent tree).
void addx(int &x, int last)
{
    const int fresh = ++cnt;
    x = fresh;
    tr[fresh] = tr[last];
}
// Builds a complete segment tree over the value range [l, r] with every count set
// to 0, and stores its root index in x.
void build(int &x, int l, int r)
{
    x = ++cnt;
    tr[x].k = 0;
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        build(tr[x].ls, l, mid);
        build(tr[x].rs, mid + 1, r);
    }
}
// Creates a new version of the tree rooted at `last` that contains one more
// occurrence of value `val`, and stores the new root in x. Only the nodes on the
// root-to-leaf path of `val` are copied; all other subtrees are shared with `last`.
void change(int &x, int last, int l, int r, int val)
{
    addx(x, last);
    ++tr[x].k;
    if (l == r) return;
    const int mid = (l + r) >> 1;
    if (val > mid)
        change(tr[x].rs, tr[last].rs, mid + 1, r, val);
    else
        change(tr[x].ls, tr[last].ls, l, mid, val);
}
// Traverses the tree from x (entered from parent fath). For every reached vertex u
// with parent par it:
//   - builds u's segment-tree version T[u] from T[par] plus u's compressed weight;
//   - sets dep[u] = dep[par] + 1;
//   - fills the binary-lifting table fa[u][0..18].
// An explicit stack replaces the original recursion, whose depth equals the tree
// height and can overflow the call stack on a path-shaped tree. A vertex is
// processed only after its parent, which is all the per-vertex work depends on.
void run(int x, int fath)
{
    std::vector<std::pair<int, int> > stk;  // pending (vertex, parent) pairs
    stk.emplace_back(x, fath);
    while (!stk.empty())
    {
        const int u = stk.back().first;
        // Named `par`, not `p`: the find() macro expands to code using the global p.
        const int par = stk.back().second;
        stk.pop_back();

        change(T[u], T[par], 1, q, find(u));
        dep[u] = dep[par] + 1;
        fa[u][0] = par;
        for (int j = 1; j <= 18; ++j)
            fa[u][j] = fa[fa[u][j - 1]][j - 1];

        for (int i = h[u]; i; i = g[i].next)
        {
            const int to = g[i].to;
            if (to != par) stk.emplace_back(to, u);
        }
    }
}
int LCA(int x, int y)
{
if(dep[x] < dep[y]) swap(x, y);
for(int i = 18; i >= 0; --i) {
if (dep[fa[x][i]] >= dep[y]) x = fa[x][i];
}
if (x == y) return x;
for(int i = 18; i >= 0; --i) {
if(fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
}
return fa[x][0];
}
// Returns the compressed index (in [l, r]) of the pl-th smallest value on a tree path.
// The path's multiset is described by four version roots:
//   u + v - lc - flc
// where lc is the LCA's version and flc is the version of the LCA's parent.
// The search descends all four trees in lockstep.
int ask(int u, int v, int lc, int flc, int l, int r, int pl)
{
    if (l == r) return l;
    const int mid = (l + r) >> 1;
    // Number of path values that fall into the left half [l, mid].
    const int leftCount = tr[tr[u].ls].k + tr[tr[v].ls].k
                        - tr[tr[lc].ls].k - tr[tr[flc].ls].k;
    // if/else ensures every path returns. The original's two independent ifs let
    // control (formally) flow off the end of a non-void function.
    if (pl <= leftCount)
        return ask(tr[u].ls, tr[v].ls, tr[lc].ls, tr[flc].ls, l, mid, pl);
    return ask(tr[u].rs, tr[v].rs, tr[lc].rs, tr[flc].rs, mid + 1, r, pl - leftCount);
}
// Reads the input in this order:
//   1. n and m;
//   2. n vertex weights;
//   3. n - 1 edges;
//   4. m encoded queries (u, v, k), where the real u is u xor (previous answer).
// Prints the k-th smallest weight on the u–v path for each query.
int main()
{
    // Bail out on malformed input or an empty tree. With n < 1 the value range
    // [1, q] would be empty and build() would never terminate.
    if (scanf("%d%d", &n, &m) != 2 || n < 1) return 0;
    for (int i = 1; i <= n; ++i)
        scanf("%d", &a[i]), p[i] = a[i];

    // Coordinate-compress the weights: p[1..q] holds the distinct values in ascending order.
    std::sort(p + 1, p + n + 1);
    q = std::unique(p + 1, p + n + 1) - p - 1;

    for (int i = 1; i < n; ++i)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v);
    }

    // Version 0 is an all-zero tree; it is the parent version of the root (vertex 1)
    // and the version used when the LCA is the root (its parent is 0).
    build(T[0], 1, q);
    run(1, 0);

    int last = 0;  // previous answer, used to decode the next query's u
    for (int i = 1; i <= m; ++i)
    {
        int u, v, k;
        scanf("%d%d%d", &u, &v, &k);
        u ^= last;
        const int lca = LCA(u, v);
        const int flca = fa[lca][0];
        const int pos = ask(T[u], T[v], T[lca], T[flca], 1, q, k);
        last = p[pos];
        printf("%d\n", last);
    }
    return 0;
}